Skip to content

Commit 22fa622

Browse files
committed
Auto-pick non-fork mp_ctx when sampling with JAX
JAX is not fork-safe and can deadlock under multiprocessing's `fork` start method. When a JAX backend is detected and the user has not specified an `mp_ctx`, fall back to `forkserver` (or to `spawn` where `forkserver` is unavailable). If the user explicitly asked for `fork`, emit a warning instead of switching. Applies to both `pm.sample` and `pm.sample_smc`.
1 parent df97bec commit 22fa622

4 files changed

Lines changed: 68 additions & 6 deletions

File tree

pymc/sampling/mcmc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,10 @@ def sample(
827827
if chains is None:
828828
chains = max(2, cores)
829829

830-
mp_ctx = _initialize_multiprocessing_context(mp_ctx, quiet=quiet)
830+
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
831+
mp_ctx = _initialize_multiprocessing_context(
832+
mp_ctx, mode=compile_kwargs.get("mode"), quiet=quiet
833+
)
831834
joined_blas_limiter, cores, num_blas_cores_per_worker = setup_cores_blas_cores(
832835
blas_cores, chains, cores, mp_ctx
833836
)
@@ -870,8 +873,6 @@ def sample(
870873
)
871874
)
872875

873-
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
874-
875876
if nuts_sampler is None:
876877
# Try to use nutpie by default if no setting is clearly at odds.
877878
# Requires all model variables, numba or jax preference,

pymc/sampling/parallel.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import platform
2020
import time
2121
import traceback
22+
import warnings
2223

2324
from collections import namedtuple
2425
from collections.abc import Sequence
@@ -28,6 +29,8 @@
2829
import cloudpickle
2930
import numpy as np
3031

32+
from pytensor.compile import get_mode
33+
from pytensor.link.jax.linker import JAXLinker
3134
from rich.theme import Theme
3235
from threadpoolctl import threadpool_limits
3336

@@ -78,8 +81,14 @@ def rebuild_exc(exc, tb):
7881

7982

8083
def _initialize_multiprocessing_context(
81-
mp_ctx: str | multiprocessing.context.BaseContext | None, quiet: bool = False
84+
mp_ctx: str | multiprocessing.context.BaseContext | None,
85+
*,
86+
mode=None,
87+
quiet: bool = False,
8288
) -> multiprocessing.context.BaseContext:
89+
user_specified = mp_ctx is not None
90+
jax_mode = mode is not None and isinstance(get_mode(mode).linker, JAXLinker)
91+
8392
if mp_ctx is None or isinstance(mp_ctx, str):
8493
# Closes issue https://github.com/pymc-devs/pymc/issues/3849
8594
# Related issue https://github.com/pymc-devs/pymc/issues/5339
@@ -96,6 +105,21 @@ def _initialize_multiprocessing_context(
96105

97106
mp_ctx = multiprocessing.get_context(mp_ctx)
98107

108+
if jax_mode and mp_ctx.get_start_method() == "fork":
109+
if user_specified:
110+
warnings.warn(
111+
"Using a JAX backend with multiprocessing start method 'fork' is unsafe "
112+
"and may deadlock. Consider passing `mp_ctx='forkserver'` or `mp_ctx='spawn'`.",
113+
UserWarning,
114+
stacklevel=2,
115+
)
116+
else:
117+
# JAX is not fork-safe: pick a non-fork default when user didn't specify.
118+
new_method = (
119+
"forkserver" if "forkserver" in multiprocessing.get_all_start_methods() else "spawn"
120+
)
121+
mp_ctx = multiprocessing.get_context(new_method)
122+
99123
return mp_ctx
100124

101125

pymc/smc/sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,16 @@ def sample_smc(
179179
else:
180180
cores = min(chains, cores)
181181

182-
kernel_kwargs["compile_kwargs"] = resolve_backend_compile_kwargs(backend, compile_kwargs)
182+
compile_kwargs = resolve_backend_compile_kwargs(backend, compile_kwargs)
183+
kernel_kwargs["compile_kwargs"] = compile_kwargs
183184

184185
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)
185186

186187
model = modelcontext(model)
187188

188189
logger.info("Initializing SMC sampler...")
189190

190-
mp_ctx = _initialize_multiprocessing_context(mp_ctx)
191+
mp_ctx = _initialize_multiprocessing_context(mp_ctx, mode=compile_kwargs.get("mode"))
191192
joined_blas_limiter, cores, num_blas_cores_per_worker = setup_cores_blas_cores(
192193
blas_cores, chains, cores, mp_ctx
193194
)

tests/sampling/test_parallel.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,42 @@ def test_context():
4141
pm.sample(tune=2, draws=2, chains=2, cores=2, mp_ctx=ctx)
4242

4343

44+
class TestMpCtxJaxSwitch:
45+
def test_switches_default_away_from_fork_under_jax(self):
46+
if ps._initialize_multiprocessing_context(None).get_start_method() != "fork":
47+
pytest.skip("platform default is not fork")
48+
with warnings.catch_warnings(record=True) as w:
49+
warnings.simplefilter("always")
50+
ctx = ps._initialize_multiprocessing_context(None, mode="JAX")
51+
assert ctx.get_start_method() != "fork"
52+
assert not any("JAX backend" in str(x.message) for x in w)
53+
54+
def test_warns_when_user_explicitly_picks_fork_under_jax(self):
55+
if "fork" not in multiprocessing.get_all_start_methods():
56+
pytest.skip("fork start method not available on this platform")
57+
with pytest.warns(UserWarning, match="JAX backend"):
58+
ctx = ps._initialize_multiprocessing_context("fork", mode="JAX")
59+
assert ctx.get_start_method() == "fork"
60+
61+
def test_leaves_non_jax_default_alone(self):
62+
expected = ps._initialize_multiprocessing_context(None).get_start_method()
63+
with warnings.catch_warnings(record=True) as w:
64+
warnings.simplefilter("always")
65+
ctx = ps._initialize_multiprocessing_context(None, mode="NUMBA")
66+
assert ctx.get_start_method() == expected
67+
assert not any("JAX backend" in str(x.message) for x in w)
68+
69+
@pytest.mark.parametrize("explicit", ["spawn", "forkserver"])
70+
def test_no_warn_when_user_picks_safe_method(self, explicit):
71+
if explicit not in multiprocessing.get_all_start_methods():
72+
pytest.skip(f"{explicit} start method not available on this platform")
73+
with warnings.catch_warnings(record=True) as w:
74+
warnings.simplefilter("always")
75+
ctx = ps._initialize_multiprocessing_context(explicit, mode="JAX")
76+
assert ctx.get_start_method() == explicit
77+
assert not any("JAX backend" in str(x.message) for x in w)
78+
79+
4480
class NoUnpickle:
4581
def __getstate__(self):
4682
return self.__dict__.copy()

0 commit comments

Comments
 (0)