Skip to content

Commit c3da9c7

Browse files
committed
Auto-pick non-fork mp_ctx when sampling with JAX
JAX is not fork-safe and can deadlock under multiprocessing's fork start method. When a JAX backend is detected and the user has not specified an `mp_ctx`, fall back to `forkserver` (or `spawn`). Warn instead of switch if the user explicitly asked for fork. Applies to both `pm.sample` and `pm.sample_smc`.
1 parent d50dad1 commit c3da9c7

4 files changed

Lines changed: 65 additions & 6 deletions

File tree

pymc/sampling/mcmc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,10 @@ def sample(
841841
if chains is None:
842842
chains = max(2, cores)
843843

844-
mp_ctx = _initialize_multiprocessing_context(mp_ctx, quiet=quiet)
844+
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
845+
mp_ctx = _initialize_multiprocessing_context(
846+
mp_ctx, mode=compile_kwargs.get("mode"), quiet=quiet
847+
)
845848
joined_blas_limiter, cores, num_blas_cores_per_worker = setup_cores_blas_cores(
846849
blas_cores, chains, cores, mp_ctx
847850
)
@@ -890,8 +893,6 @@ def sample(
890893
)
891894
)
892895

893-
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
894-
895896
if nuts_sampler is None:
896897
# Nutpie must take all the variables and can only compile to Numba or JAX
897898
can_use_nutpie = (

pymc/sampling/parallel.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import platform
2020
import time
2121
import traceback
22+
import warnings
2223

2324
from collections import namedtuple
2425
from collections.abc import Sequence
@@ -28,6 +29,8 @@
2829
import cloudpickle
2930
import numpy as np
3031

32+
from pytensor.compile import get_mode
33+
from pytensor.link.jax.linker import JAXLinker
3134
from rich.theme import Theme
3235
from threadpoolctl import threadpool_limits
3336

@@ -78,8 +81,14 @@ def rebuild_exc(exc, tb):
7881

7982

8083
def _initialize_multiprocessing_context(
81-
mp_ctx: str | multiprocessing.context.BaseContext | None, quiet: bool = False
84+
mp_ctx: str | multiprocessing.context.BaseContext | None,
85+
*,
86+
mode=None,
87+
quiet: bool = False,
8288
) -> multiprocessing.context.BaseContext:
89+
user_specified = mp_ctx is not None
90+
jax_mode = mode is not None and isinstance(get_mode(mode).linker, JAXLinker)
91+
8392
if mp_ctx is None or isinstance(mp_ctx, str):
8493
# Closes issue https://github.com/pymc-devs/pymc/issues/3849
8594
# Related issue https://github.com/pymc-devs/pymc/issues/5339
@@ -96,6 +105,21 @@ def _initialize_multiprocessing_context(
96105

97106
mp_ctx = multiprocessing.get_context(mp_ctx)
98107

108+
if jax_mode and mp_ctx.get_start_method() == "fork":
109+
if user_specified:
110+
warnings.warn(
111+
"Using a JAX backend with multiprocessing start method 'fork' is unsafe "
112+
"and may deadlock. Consider passing `mp_ctx='forkserver'` or `mp_ctx='spawn'`.",
113+
UserWarning,
114+
stacklevel=2,
115+
)
116+
else:
117+
# JAX is not fork-safe: pick a non-fork default when user didn't specify.
118+
new_method = (
119+
"forkserver" if "forkserver" in multiprocessing.get_all_start_methods() else "spawn"
120+
)
121+
mp_ctx = multiprocessing.get_context(new_method)
122+
99123
return mp_ctx
100124

101125

pymc/smc/sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,16 @@ def sample_smc(
179179
else:
180180
cores = min(chains, cores)
181181

182-
kernel_kwargs["compile_kwargs"] = resolve_backend_compile_kwargs(backend, compile_kwargs)
182+
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
183+
kernel_kwargs["compile_kwargs"] = compile_kwargs
183184

184185
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)
185186

186187
model = modelcontext(model)
187188

188189
logger.info("Initializing SMC sampler...")
189190

190-
mp_ctx = _initialize_multiprocessing_context(mp_ctx)
191+
mp_ctx = _initialize_multiprocessing_context(mp_ctx, mode=compile_kwargs.get("mode"))
191192
joined_blas_limiter, cores, num_blas_cores_per_worker = setup_cores_blas_cores(
192193
blas_cores, chains, cores, mp_ctx
193194
)

tests/sampling/test_parallel.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,39 @@ def test_context():
4141
pm.sample(tune=2, draws=2, chains=2, cores=2, mp_ctx=ctx)
4242

4343

44+
@pytest.mark.skipif(
45+
platform.system() == "Darwin" and platform.processor() != "arm",
46+
reason="Default mp_ctx on non-ARM macOS is already forkserver",
47+
)
48+
class TestMpCtxJaxSwitch:
49+
def test_switches_default_away_from_fork_under_jax(self):
50+
with warnings.catch_warnings(record=True) as w:
51+
warnings.simplefilter("always")
52+
ctx = ps._initialize_multiprocessing_context(None, mode="JAX")
53+
assert ctx.get_start_method() != "fork"
54+
assert not any("JAX backend" in str(x.message) for x in w)
55+
56+
def test_warns_when_user_explicitly_picks_fork_under_jax(self):
57+
with pytest.warns(UserWarning, match="JAX backend"):
58+
ctx = ps._initialize_multiprocessing_context("fork", mode="JAX")
59+
assert ctx.get_start_method() == "fork"
60+
61+
def test_leaves_non_jax_default_alone(self):
62+
with warnings.catch_warnings(record=True) as w:
63+
warnings.simplefilter("always")
64+
ctx = ps._initialize_multiprocessing_context(None, mode="NUMBA")
65+
assert ctx.get_start_method() == "fork"
66+
assert not any("JAX backend" in str(x.message) for x in w)
67+
68+
@pytest.mark.parametrize("explicit", ["spawn", "forkserver"])
69+
def test_no_warn_when_user_picks_safe_method(self, explicit):
70+
with warnings.catch_warnings(record=True) as w:
71+
warnings.simplefilter("always")
72+
ctx = ps._initialize_multiprocessing_context(explicit, mode="JAX")
73+
assert ctx.get_start_method() == explicit
74+
assert not any("JAX backend" in str(x.message) for x in w)
75+
76+
4477
class NoUnpickle:
4578
def __getstate__(self):
4679
return self.__dict__.copy()

0 commit comments

Comments
 (0)