Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 118 additions & 4 deletions mpisppy/spbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,124 @@ def spcomm(self, value):


def allreduce_or(self, val):
local_val = np.array([val], dtype='int8')
global_val = np.zeros(1, dtype='int8')
self.mpicomm.Allreduce(local_val, global_val, op=MPI.LOR)
if global_val[0] > 0:
# ====== DEBUG: LOR_bug instrumentation ======
# Yields per call (on cyl_rk == 0 of self.mpicomm) the full picture
# needed to localize an Allreduce(LOR) returning nonzero when every
# rank intends a zero. Probes four axes:
# 1. comm membership (world ranks participating, size, uniqueness)
# 2. data going in (Allgather of every rank's local_val)
# 3. reduction sanity (SUM/MAX/LOR + a rank-sum check whose
# expected value is n*(n-1)/2)
# 4. consistency (compare Allgather sum to Allreduce SUM
# to tell "input was wrong" from
# "Allreduce is wrong")
# Remove before merging to main. See PR description for hypothesis tree.
import os
import socket
import sys
sz = self.mpicomm.Get_size()
cyl_rk = self.mpicomm.Get_rank()
world_rk = MPI.COMM_WORLD.Get_rank()
host = socket.gethostname()
pid = os.getpid()

local_int = 1 if val else 0
local_int32 = np.array([local_int], dtype='int32')
local_int8 = np.array([local_int], dtype='int8')

# (3) Reductions — three ops in parallel, plus a rank-sum sanity check.
sum_out = np.zeros(1, dtype='int32')
self.mpicomm.Allreduce(local_int32, sum_out, op=MPI.SUM)

max_out = np.zeros(1, dtype='int32')
self.mpicomm.Allreduce(local_int32, max_out, op=MPI.MAX)

lor_out = np.zeros(1, dtype='int8')
self.mpicomm.Allreduce(local_int8, lor_out, op=MPI.LOR)

rank_in = np.array([cyl_rk], dtype='int32')
rank_out = np.zeros(1, dtype='int32')
self.mpicomm.Allreduce(rank_in, rank_out, op=MPI.SUM)
expected_rank_sum = sz * (sz - 1) // 2

# (1) + (2) Allgather of (world_rk, cyl_rk, local_int) so we see
# exactly which ranks participated and what each one contributed.
report = np.array([world_rk, cyl_rk, local_int], dtype='int32')
all_reports = np.zeros(3 * sz, dtype='int32')
self.mpicomm.Allgather(report, all_reports)

# Track a per-instance call counter so logs are correlatable.
self._lor_diag_count = getattr(self, "_lor_diag_count", 0) + 1
call_n = self._lor_diag_count

if cyl_rk == 0:
rows = all_reports.reshape(sz, 3)
wr = rows[:, 0].tolist()
nonzero_rows = rows[rows[:, 2] != 0]
gather_sum = int(rows[:, 2].sum())
cls = type(self).__name__
try:
comm_name = self.mpicomm.Get_name()
except Exception:
comm_name = "<unknown>"
print(
f"[LOR_bug call={call_n} cls={cls} "
f"world_rk={world_rk} host={host} pid={pid}] "
f"mpicomm size={sz} name={comm_name!r}",
flush=True,
)
print(
f" world_ranks: min={min(wr)} max={max(wr)} "
f"count={len(wr)} unique={len(set(wr))}",
flush=True,
)
print(
f" reductions: sum={int(sum_out[0])} max={int(max_out[0])} "
f"lor={int(lor_out[0])} rank_sum={int(rank_out[0])} "
f"expected_rank_sum={expected_rank_sum}",
flush=True,
)
print(
f" gather: gather_sum={gather_sum} "
f"nonzero_reports={len(nonzero_rows)}",
flush=True,
)
# "Bad" = invariant-violating, NOT just "nonzero result." A
# legitimate shutdown signal returns sum=lor=1 with
# gather_sum=1 (consistent), which is fine. The real bug
# signature is gather_sum disagreeing with the Allreduce SUM
# (the reducer lying), or the rank-sum sanity check failing
# (SUM broken on this comm), or duplicate world ranks
# (group membership corrupted), or some rank packing >1
# (non-boolean input — only possible under memory aliasing).
bad = (
int(rank_out[0]) != expected_rank_sum
or int(sum_out[0]) != gather_sum
or len(set(wr)) != len(wr)
or int(max_out[0]) > 1
)
if bad:
limit = min(64, len(nonzero_rows))
for w, c, v in nonzero_rows[:limit].tolist():
print(
f" nonzero: world_rk={w} cyl_rk={c} local_val={v}",
flush=True,
)
if len(nonzero_rows) > limit:
print(
f" (... {len(nonzero_rows) - limit} more nonzero rows truncated ...)",
flush=True,
)
# Also dump the full world-rank list once so we can see exactly
# who is participating in this comm.
print(
f" ALL world_ranks: {wr}",
flush=True,
)
sys.stdout.flush()
# ====== END DEBUG ======

if lor_out[0] > 0:
return True
else:
return False
Expand Down
Loading