Skip to content

Commit 6b30fd4

Browse files
committed
Merge fix/lazy-mpi-init into ngmix_v2.0 (#744)
Gate mpi4py import on MPI launcher environment so bare shapepipe_run inside an srun shell no longer aborts in MPI_Init. Same commits as PR #747; folded here so the fix ships with the ngmix v2.0 line.
2 parents 6bae2a4 + dab2d5c commit 6b30fd4

2 files changed

Lines changed: 76 additions & 6 deletions

File tree

src/shapepipe/run.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
77
"""
88

9+
import os
910
import sys
1011
from datetime import datetime
1112
from importlib.metadata import requires
@@ -22,12 +23,27 @@
2223
from shapepipe.pipeline.job_handler import JobHandler
2324
from shapepipe.pipeline.mpi_run import split_mpi_jobs, submit_mpi_jobs
2425

25-
try:
26-
from mpi4py import MPI
27-
except ImportError: # pragma: no cover
28-
import_mpi = False
26+
# Importing mpi4py initializes MPI immediately, which aborts the whole
27+
# process when no MPI launcher is available — e.g. inside an
28+
# ``srun``-launched shell on a SLURM cluster, where Open MPI detects the
29+
# SLURM step environment, expects a PMI server that srun never started,
30+
# and calls MPI_Abort before even ``shapepipe_run -h`` can print (#744).
31+
# Only import (and hence initialize) MPI when a launcher environment is
32+
# actually present: ``mpirun``/``orterun`` set OMPI_COMM_WORLD_SIZE,
33+
# ``srun --mpi=pmi2`` sets PMI_RANK and ``srun --mpi=pmix`` sets
34+
# PMIX_RANK. A bare ``shapepipe_run`` (login node, compute-node shell,
35+
# container) runs in SMP mode without ever touching MPI.
36+
_MPI_LAUNCHER_VARS = ("OMPI_COMM_WORLD_SIZE", "PMI_RANK", "PMIX_RANK")
37+
38+
if any(var in os.environ for var in _MPI_LAUNCHER_VARS):
39+
try:
40+
from mpi4py import MPI
41+
except ImportError: # pragma: no cover
42+
import_mpi = False
43+
else:
44+
import_mpi = True
2945
else:
30-
import_mpi = True
46+
import_mpi = False
3147

3248

3349
class ShapePipe:
@@ -178,7 +194,7 @@ def _check_dependencies(self):
178194
module_dep = self._get_module_depends("depends") + __installs__
179195
module_exe = self._get_module_depends("executes")
180196

181-
module_dep += ["mpi4py"] if import_mpi else module_dep
197+
module_dep += ["mpi4py"] if import_mpi else []
182198

183199
exe_to_module = {
184200
exe: module

src/shapepipe/tests/test_run.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""UNIT TESTS FOR RUN.
2+
3+
This module contains unit tests for the shapepipe.run module, in
4+
particular the MPI-launcher gating of the mpi4py import (#744): a bare
5+
``shapepipe_run`` must never initialize MPI, otherwise the whole process
6+
aborts inside an ``srun``-launched shell whose Open MPI lacks SLURM PMI
7+
support.
8+
9+
:Author: Claude (on behalf of Cail Daley) <cail.daley@cea.fr>
10+
11+
"""
12+
13+
import os
14+
import subprocess
15+
import sys
16+
17+
import pytest
18+
19+
SNIPPET = "import shapepipe.run as r; print(r.import_mpi)"
20+
21+
# Env vars that either mark an MPI launcher (the gate) or make Open MPI
22+
# believe it was direct-launched by srun (the failure mode under test).
23+
_SCRUBBED_PREFIXES = ("OMPI_", "PMI_", "PMIX_", "SLURM_")
24+
25+
26+
def _import_mpi_flag(extra_env):
27+
"""Report shapepipe.run.import_mpi in a subprocess with a clean env."""
28+
env = {
29+
key: value
30+
for key, value in os.environ.items()
31+
if not key.startswith(_SCRUBBED_PREFIXES)
32+
}
33+
env.update(extra_env)
34+
result = subprocess.run(
35+
[sys.executable, "-c", SNIPPET],
36+
env=env,
37+
capture_output=True,
38+
text=True,
39+
)
40+
assert result.returncode == 0, (
41+
f"subprocess failed (exit {result.returncode}): {result.stderr}"
42+
)
43+
return result.stdout.strip()
44+
45+
46+
def test_bare_launch_skips_mpi():
47+
"""A bare launch (no MPI launcher env) must not import/init MPI."""
48+
assert _import_mpi_flag({}) == "False"
49+
50+
51+
def test_mpirun_launch_imports_mpi():
52+
"""An mpirun-style env (OMPI_COMM_WORLD_SIZE) must import MPI."""
53+
pytest.importorskip("mpi4py")
54+
assert _import_mpi_flag({"OMPI_COMM_WORLD_SIZE": "1"}) == "True"

0 commit comments

Comments
 (0)