Commit 4604131

wanghan-iapcm and Han Wang authored
fix(pt_expt): fail-fast on .pt2 GNN inference without LAMMPS atom-map (#5450)
## Summary

- Surface a previously-silent corruption / CUDA index assert in LAMMPS `.pt2` inference for message-passing models (DPA2, DPA3, hybrids over those) when the LAMMPS atom-map is not enabled. Previously the C++ side fell into an identity-mapping fallback (`DeepPotPTExpt.cc:374-384`) whose values are wrong for ghost slots; the model's `_exchange_ghosts` (`deepmd/dpmodel/descriptor/repformers.py`) then performed `take_along_axis(g1[1, nloc, dim], mapping_tiled)` with out-of-bounds gather indices for ghosts: a CUDA index assert in the user's DPA4 report, undefined CPU output otherwise.
- Add a `has_message_passing` field to `.pt2` metadata (mirrors the descriptor's `has_message_passing()` API: true for DPA2/DPA3/hybrids over those; false for se_e2_a/DPA1/etc.). Gate the fail-fast in `DeepPotPTExpt::compute_inner` and `DeepSpinPTExpt::compute_inner` on it. Non-GNN models retain their previous behaviour.
- Two error messages target the two distinct unsupported configurations:
  - **Single-rank without atom-map**: "Single-rank LAMMPS .pt2 inference requires `atom_modify map yes`…"
  - **Multi-rank without a with-comm artifact**: "Multi-rank LAMMPS .pt2 inference requires the model to be exported with `use_loc_mapping=False`…"
- Refined predicate: `has_message_passing_ && !use_with_comm && !atom_map_present && nghost > 0`. The `nghost > 0` guard skips NoPbc and isolated-cluster cases where identity over `[0, nloc)` is trivially correct.
### Four-cell coverage matrix in `test_lammps_dpa3_pt2.py`

| Cell | `use_loc_mapping` | atom-map | nprocs | Path | Test |
|---|---|---|---|---|---|
| A | True (regular only) | yes | 1 | regular w/ correct mapping | `test_pair_deepmd` *(existing)* |
| B | True | no | 1 | **fail fast** (single-rank msg) | `test_pair_deepmd_no_atom_map_fails_fast` *(new)* |
| B-mr | True | any | >1 | **fail fast** (multi-rank msg) | `test_pair_deepmd_mpi_no_with_comm_fails_fast` *(new, subprocess)* |
| C | False (regular + with-comm) | yes | 1 | regular w/ atom-map | `test_pair_deepmd_with_comm` *(new)* |
| C-mr | False | any | >1 | with-comm (`border_op`) | `test_pair_deepmd_mpi_dpa3` *(existing)* |
| D | False | no | 1 | **fail fast** (single-rank PBC can't drive border_op) | `test_pair_deepmd_with_comm_no_atom_map_fails_fast` *(new)* |
| D-mr | False | no | >1 | with-comm (mapping-free) | `test_pair_deepmd_mpi_no_atom_map` *(new, subprocess)* |

### Investigation note (resolves an earlier mystery)

`test_deeppot_dpa_ptexpt.cc` is misleadingly named: despite the `Dpa` prefix it loads `deeppot_dpa1.pt2` (DPA1, non-message-passing). Its regular `.pt2` graph never consumes `mapping` for ghost gather, so the identity fallback was trivially safe and the test passed without explicit `inlist.mapping`. The genuinely-DPA2 ctest is `test_deeppot_dpa2_ptexpt.cc` (different file), which already explicitly sets `inlist.mapping = mapping.data();` on all `cpu_lmp_nlist*` paths. **No C++ ctest fixtures need editing in this PR**, because the metadata-gated fail-fast correctly skips DPA1.

### Backward compatibility

`has_message_passing_` defaults to **false** in C++ when the metadata field is missing, so pre-PR `.pt2` archives retain their previous behaviour. Non-GNN pre-PR archives continue to work; GNN pre-PR archives must be regenerated to opt into the fail-fast guard. In-tree fixtures are generated by `gen_*.py` at CI time, which always writes the new field.
## Test plan

- [x] Local C++ ctest `*PtExpt*` filter: **160 / 160 PASSED** (270 s) against freshly-regenerated `.pt2` fixtures.
- [ ] CI runs the negative cells (B / B-mr / D); they exercise the throw and verify the error-message substrings. The pytest assertions use `pytest.raises(Exception, match=r"atom_modify map yes")` and the stdout/stderr substring `use_loc_mapping=False`; if LAMMPS wraps the exception with a prefix/suffix differently than expected, the match may need adjustment.
- [ ] CI cell D-mr (`test_pair_deepmd_mpi_no_atom_map`) verifies the with-comm artifact handles ghosts via `border_op` without consuming the mapping tensor.

## Known limitations

- Multi-rank with `use_loc_mapping=True` is permanently unsupported by this fix: the fail-fast surfaces it clearly, and there is no path forward without re-export.
- Single-rank PBC + with-comm artifact + no atom-map (cell D) could be made to work via a synthesized self-mirror `comm_dict`; deferred to a follow-up.
- `MPI_Comm_size` is not used as the multi-rank predicate because `api_cc` does not link MPI directly; `lmp_list.nswap > 0` serves as the proxy (equivalent for all current LAMMPS configurations).
- The pre-PR DPA3 `use_loc_mapping=True` archives lacking the new metadata field continue to exhibit the silent-corruption bug; users must regenerate.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * CLI flag to disable atom-ID→local-index mapping for test runs; generator now produces a single-artifact spin model variant; APIs allow callers to declare MPI rank count for neighbor lists.
* **Bug Fixes**
  * Serialized metadata now records a message-passing capability flag; runtime enforces compatibility and surfaces clear fail-fast errors when required atom-mapping or artifacts are missing.
* **Tests**
  * Expanded coverage for message-passing variants, atom-map on/off scenarios, single- vs multi-rank MPI cases, and related fail-fast behaviors.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent d3f08f3 commit 4604131

16 files changed

Lines changed: 616 additions & 26 deletions

deepmd/pt_expt/utils/serialization.py

Lines changed: 35 additions & 0 deletions
@@ -442,6 +442,41 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
     # (per-layer ghost-feature MPI exchange via deepmd_export::border_op).
     # The C++ DeepPotPTExpt / DeepSpinPTExpt loaders branch on this flag.
     meta["has_comm_artifact"] = _needs_with_comm_artifact(model)
+
+    # Whether the model's regular .pt2 graph consumes the ``mapping``
+    # tensor to gather per-layer ghost-atom features from local atoms.
+    # Mirrors the descriptor's ``has_message_passing()`` API: True for
+    # any message-passing descriptor (DPA2, DPA3, hybrids over those);
+    # False for non-message-passing descriptors (se_e2_a, DPA1, etc.).
+    # The C++ side gates its fail-fast on this — an absent mapping is
+    # fatal only for models that would silently corrupt ghost features
+    # otherwise.
+    #
+    # Lookup order: model -> atomic_model -> descriptor. Going through
+    # ``atomic_model.has_message_passing()`` is important for composite
+    # atomic models (e.g. ``LinearAtomicModel`` in DP-ZBL) which don't
+    # expose a single ``.descriptor`` but do aggregate the flag across
+    # their sub-models. ``descriptor.has_message_passing()`` is the
+    # fallback for any future wrapper that lacks the higher-level
+    # methods.
+    def _probe_has_message_passing(obj: object) -> bool | None:
+        if obj is None or not hasattr(obj, "has_message_passing"):
+            return None
+        try:
+            return bool(obj.has_message_passing())
+        except (AttributeError, NotImplementedError):
+            return None
+
+    result: bool | None = None
+    for obj in (
+        model,
+        getattr(model, "atomic_model", None),
+        getattr(getattr(model, "atomic_model", None), "descriptor", None),
+    ):
+        result = _probe_has_message_passing(obj)
+        if result is not None:
+            break
+    meta["has_message_passing"] = result if result is not None else False
     return meta


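The lookup order in this hunk (model, then `atomic_model`, then `descriptor`) can be exercised without loading a real model. The sketch below is illustrative only: `resolve_has_message_passing` and the `Fake*` classes are hypothetical stand-ins, not deepmd-kit APIs, and they assume nothing beyond the duck-typed `has_message_passing()` method that the hunk itself relies on.

```python
from typing import Optional


def probe_has_message_passing(obj: object) -> Optional[bool]:
    """Return the flag reported by ``obj``, or None if it cannot report one."""
    if obj is None or not hasattr(obj, "has_message_passing"):
        return None
    try:
        return bool(obj.has_message_passing())
    except (AttributeError, NotImplementedError):
        return None


def resolve_has_message_passing(model: object) -> bool:
    """Probe model -> atomic_model -> descriptor; fall back to False."""
    atomic_model = getattr(model, "atomic_model", None)
    candidates = (model, atomic_model, getattr(atomic_model, "descriptor", None))
    for candidate in candidates:
        result = probe_has_message_passing(candidate)
        if result is not None:
            return result
    return False


# Hypothetical stand-ins for illustration; these are not deepmd-kit classes.
class FakeDescriptor:
    def has_message_passing(self) -> bool:
        return True


class FakeSimpleAtomicModel:
    """Exposes a descriptor but no aggregate flag of its own."""

    def __init__(self) -> None:
        self.descriptor = FakeDescriptor()


class FakeCompositeAtomicModel:
    """Aggregates the flag itself and exposes no ``.descriptor``."""

    def has_message_passing(self) -> bool:
        return True


class FakeModel:
    def __init__(self, atomic_model: object) -> None:
        self.atomic_model = atomic_model
```

With these stand-ins, a composite atomic model is resolved at the `atomic_model` step, a simple one falls through to its descriptor, and an object that exposes none of the three resolves to `False`.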

source/api_c/include/c_api.h

Lines changed: 5 additions & 1 deletion
@@ -52,6 +52,9 @@ extern DP_Nlist* DP_NewNlist(int inum_,
  * each swap.
  * @param[in] world Pointer to the MPI communicator or similar communication
  * world used for the operation.
+ * @param[in] nprocs Number of MPI ranks (1 = single-rank). Used by
+ * ``DeepPotPTExpt`` / ``DeepSpinPTExpt`` to choose between the regular
+ * and with-comm artifacts. Defaults to 1 if not supplied.
  * @returns A pointer to the initialized neighbor list with communication
  * capabilities.
  */
@@ -66,7 +69,8 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
                                   int** sendlist,
                                   int* sendproc,
                                   int* recvproc,
-                                  void* world);
+                                  void* world,
+                                  int nprocs);

 /**
  * @brief Set mask for a neighbor list.

source/api_c/include/deepmd.hpp

Lines changed: 4 additions & 2 deletions
@@ -831,7 +831,8 @@ struct InputNlist {
              int** sendlist,
              int* sendproc,
              int* recvproc,
-             void* world)
+             void* world,
+             int nprocs = 1)
       : inum(inum_),
         ilist(ilist_),
         numneigh(numneigh_),
@@ -847,7 +848,8 @@ struct InputNlist {
                             sendlist,
                             sendproc,
                             recvproc,
-                            world)) {};
+                            world,
+                            nprocs)) {};
   ~InputNlist() { DP_DeleteNlist(nl); };
   /// @brief C API neighbor list.
   DP_Nlist* nl;

source/api_c/src/c_api.cc

Lines changed: 4 additions & 3 deletions
@@ -35,10 +35,11 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
                            int** sendlist,
                            int* sendproc,
                            int* recvproc,
-                           void* world) {
+                           void* world,
+                           int nprocs) {
   deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, nswap, sendnum,
-                        recvnum, firstrecv, sendlist, sendproc, recvproc,
-                        world);
+                        recvnum, firstrecv, sendlist, sendproc, recvproc, world,
+                        nprocs);
   DP_Nlist* new_nl = new DP_Nlist(nl);
   return new_nl;
 }

source/api_cc/include/DeepPotPTExpt.h

Lines changed: 9 additions & 0 deletions
@@ -226,6 +226,15 @@ class DeepPotPTExpt : public DeepPotBackend {
   // passing. ``with_comm_tempfile_`` owns the extracted nested .pt2
   // for the lifetime of ``with_comm_loader``.
   bool has_comm_artifact_ = false;
+  // Whether the regular .pt2 graph consumes the mapping tensor for
+  // ghost-feature gather (true for any message-passing descriptor:
+  // DPA2/DPA3/hybrids; false for se_e2_a/DPA1/etc.). Mirrors the
+  // descriptor's ``has_message_passing()`` API; read from the
+  // ``has_message_passing`` metadata field. Defaults to false for
+  // pre-PR .pt2 archives that lack the field so non-GNN archives
+  // continue to work; GNN archives must be regenerated to opt into
+  // the fail-fast guard against the silent-corruption bug.
+  bool has_message_passing_ = false;
   std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
   std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

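The backward-compatibility rule described in this comment (a missing field reads as false) can be sketched in a few lines. This is a minimal illustration under assumptions: it treats the archive metadata as a JSON object, and `read_has_message_passing` is a hypothetical helper name, not part of deepmd-kit.

```python
import json


def read_has_message_passing(metadata_json: str) -> bool:
    """Read the flag from JSON metadata; a missing field counts as False."""
    metadata = json.loads(metadata_json)
    return bool(metadata.get("has_message_passing", False))


# Assumed example payloads: an archive written before the field existed,
# and one written afterwards for a message-passing model.
pre_pr_archive = '{"has_comm_artifact": false}'
post_pr_gnn_archive = '{"has_comm_artifact": false, "has_message_passing": true}'
```

Under this rule a pre-PR archive never triggers the fail-fast, which is why the comment asks for GNN archives to be regenerated.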

source/api_cc/include/DeepSpinPTExpt.h

Lines changed: 3 additions & 0 deletions
@@ -196,6 +196,9 @@ class DeepSpinPTExpt : public DeepSpinBackend {
   std::unique_ptr<torch::inductor::AOTIModelPackageLoader> loader;
   // Optional with-comm artifact for multi-rank GNN spin inference.
   bool has_comm_artifact_ = false;
+  // Mirrors descriptor's has_message_passing(). See DeepPotPTExpt.h
+  // for the full rationale and gating role.
+  bool has_message_passing_ = false;
   std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
   std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;


source/api_cc/src/DeepPotPTExpt.cc

Lines changed: 70 additions & 8 deletions
@@ -172,6 +172,18 @@ void DeepPotPTExpt::init(const std::string& model,
   // exchange and producing wrong results.
   has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
                        metadata["has_comm_artifact"].as_bool();
+  // Whether the regular .pt2 graph consumes ``mapping`` for ghost-atom
+  // feature gather. Mirrors the descriptor's ``has_message_passing()``
+  // API: true for message-passing descriptors (DPA2, DPA3, hybrids
+  // over those), false for non-message-passing descriptors (se_e2_a,
+  // DPA1, etc.). Pre-PR .pt2 archives lack this field; default to
+  // false so they retain their previous behaviour (non-GNN archives
+  // continue to work; GNN archives that had the original
+  // silent-corruption bug must be regenerated to opt into the fail-
+  // fast guard). All in-tree fixtures are regenerated by the gen
+  // scripts and carry the explicit value.
+  has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
+                         metadata["has_message_passing"].as_bool();
   if (has_comm_artifact_) {
     try {
       // Extract the nested ``extra/forward_lower_with_comm.pt2`` into a
@@ -353,6 +365,51 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
           .clone()
           .to(device);

+  // Dispatch decision: use the with-comm artifact when LAMMPS is running
+  // multi-rank. ``lmp_list.nprocs > 1`` is the direct predicate;
+  // LAMMPS pair styles populate it by passing ``comm->nprocs`` to the
+  // ``InputNlist`` constructor. Earlier drafts used ``nswap > 0`` as a
+  // proxy, but that breaks for ``atom_style spin`` (which emits
+  // nswap > 0 even in single-rank to propagate PBC ghost spins).
+  // ``nprocs`` is unambiguous.
+  //
+  // The regular artifact uses ``mapping`` to gather ghost-atom features
+  // from local-atom embeddings (``index_select(node_ebd[1, nloc, dim],
+  // mapping)``). Identity-mapping for ghost slots is silently wrong,
+  // so fail-fast when the regular path would be taken without a real
+  // mapping — applies uniformly to every caller (LAMMPS pair, ctest
+  // fixtures, direct C++ API users). Callers that want the regular
+  // path must populate ``lmp_list.mapping``.
+  bool multi_rank = (lmp_list.nprocs > 1);
+  bool atom_map_present = (lmp_list.mapping != nullptr);
+  bool use_with_comm = has_comm_artifact_ && multi_rank;
+  // Decision matrix (see PR #5450 description):
+  //   non-GNN model (has_message_passing_ == false): regular path is
+  //     always safe.
+  //   nghost == 0 (NoPbc, isolated cluster): always safe.
+  //   GNN model, multi-rank: requires has_comm_artifact_ (cell C-mr / D-mr)
+  //     else fail-fast (cell B-mr)
+  //   GNN model, single-rank: requires atom_map_present (cell A / C)
+  //     else fail-fast (cell B / D)
+  if (has_message_passing_ && nghost > 0) {
+    if (multi_rank && !has_comm_artifact_) {
+      throw deepmd::deepmd_exception(
+          "Multi-rank LAMMPS .pt2 inference requires the model to be "
+          "exported with `use_loc_mapping=False`, which compiles a "
+          "with-comm artifact for cross-rank ghost-feature exchange. "
+          "Re-export the model with use_loc_mapping=False and try again.");
+    }
+    if (!multi_rank && !atom_map_present) {
+      throw deepmd::deepmd_exception(
+          "Single-rank LAMMPS .pt2 inference requires `atom_modify map "
+          "yes` in the LAMMPS input (so InputNlist.mapping is populated "
+          "from the LAMMPS atom-map). The model gathers ghost-atom "
+          "features via this mapping; without it the C++ side has no "
+          "safe way to resolve ghost indices to local owners. C++ API "
+          "callers must set inlist.mapping explicitly before compute().");
+    }
+  }
+
   // LAMMPS sets ago=0 on every nlist rebuild (neighbor rebuild, re-partition,
   // atom exchange between subdomains), so `ago > 0` implies the cached
   // mapping and nlist tensors are still valid. Rebuild only on ago==0.
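The decision matrix in this hunk can be modelled as a small pure function, which makes it easy to check each cell of the coverage table. The sketch below is a Python paraphrase of the C++ branch, not the shipped implementation: `DispatchInput`, `DispatchError`, `choose_path` and `describe` are hypothetical names, and the error strings are shortened versions of the messages above.

```python
from dataclasses import dataclass


class DispatchError(RuntimeError):
    """Hypothetical stand-in for deepmd::deepmd_exception."""


@dataclass(frozen=True)
class DispatchInput:
    has_message_passing: bool  # read from the .pt2 metadata
    has_comm_artifact: bool    # model exported with use_loc_mapping=False
    nprocs: int                # number of MPI ranks reported by the caller
    atom_map_present: bool     # lmp_list.mapping is non-null
    nghost: int                # number of ghost atoms


def choose_path(cfg: DispatchInput) -> str:
    """Return "with_comm" or "regular"; raise for unsupported setups."""
    multi_rank = cfg.nprocs > 1
    if cfg.has_message_passing and cfg.nghost > 0:
        if multi_rank and not cfg.has_comm_artifact:
            raise DispatchError("re-export with use_loc_mapping=False")
        if not multi_rank and not cfg.atom_map_present:
            raise DispatchError("requires atom_modify map yes")
    return "with_comm" if (cfg.has_comm_artifact and multi_rank) else "regular"


def describe(cfg: DispatchInput) -> str:
    """Like choose_path, but report a fail-fast as text instead of raising."""
    try:
        return choose_path(cfg)
    except DispatchError as exc:
        return f"fail-fast: {exc}"


# Cell D from the coverage matrix: with-comm artifact, single rank, no atom-map.
cell_d = DispatchInput(True, True, 1, False, 8)
# Cell D-mr: the same artifact on more than one rank runs mapping-free.
cell_d_mr = DispatchInput(True, True, 2, False, 8)
```

Calling `describe(cell_d)` reports the single-rank fail-fast, while `describe(cell_d_mr)` selects the with-comm path, matching rows D and D-mr of the table.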
@@ -372,7 +429,15 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
               .clone()
               .to(device);
     } else {
-      // Default identity mapping for local atoms
+      // Identity fallback. The fail-fast above guarantees we only
+      // reach this branch when one of these is true:
+      //   - The model is non-message-passing (mapping is unused).
+      //   - ``nghost == 0`` (no ghosts to gather, identity is trivially
+      //     correct).
+      //   - ``use_with_comm`` is true (the with-comm graph fills ghost
+      //     features via border_op and ignores this tensor for ghost
+      //     gather — see deepmd/pt_expt/descriptor/
+      //     repflows.py::_exchange_ghosts).
       std::vector<std::int64_t> mapping(nall_real);
       for (int ii = 0; ii < nall_real; ii++) {
         mapping[ii] = ii;
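Why the identity fallback is only safe in the listed cases can be seen with a toy gather. The sketch below uses plain Python lists in place of tensors; the atom counts, feature values and the names `gather_ghost_features` and `mapping_is_in_bounds` are illustrative assumptions, not code from the repository.

```python
def mapping_is_in_bounds(mapping, nloc):
    """True if every entry indexes a local atom, i.e. lies in [0, nloc)."""
    return all(0 <= idx < nloc for idx in mapping)


def gather_ghost_features(local_features, mapping):
    """Extend per-local-atom features to all atoms by indexing with ``mapping``."""
    return [local_features[idx] for idx in mapping]


nloc, nghost = 3, 2
local_features = [10.0, 11.0, 12.0]  # one feature per local atom

# Real atom-map: ghost slots 3 and 4 are periodic images of local atoms 0 and 2.
atom_map = [0, 1, 2, 0, 2]
# Identity fallback over all atoms: ghost slots point past the local range.
identity = list(range(nloc + nghost))
```

With the real atom-map the gather copies each owner's feature into its ghost slot. The identity mapping has entries 3 and 4, which lie outside `[0, nloc)`; restricted to the first `nloc` entries (the `nghost == 0` case) it is in bounds.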
@@ -428,14 +493,11 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
     aparam_tensor = torch::zeros({0}, options).to(device);
   }

-  // Phase 4 dispatch: use the with-comm artifact when LAMMPS is
-  // running multi-rank. ``lmp_list.nswap > 0`` is the proxy for
-  // "multi-rank with cross-domain communication"; in single-rank
-  // mode LAMMPS sets nswap=0. Falling back to the regular artifact
-  // for nswap=0 is correct because that artifact uses the mapping
-  // tensor to gather ghost embeddings from local atoms.
+  // ``use_with_comm`` was computed earlier alongside the fail-fast
+  // dispatch check. Use the with-comm artifact for the multi-rank case
+  // (the regular artifact uses the mapping tensor to gather ghost
+  // embeddings, which only works in single-rank).
   std::vector<torch::Tensor> flat_outputs;
-  bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
   if (use_with_comm && !with_comm_loader) {
     throw deepmd::deepmd_exception(
         "Multi-rank LAMMPS requires the with-comm artifact, but it failed "

source/api_cc/src/DeepSpinPTExpt.cc

Lines changed: 52 additions & 1 deletion
@@ -179,6 +179,10 @@ void DeepSpinPTExpt::init(const std::string& model,
   // dropping the MPI exchange.
   has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
                        metadata["has_comm_artifact"].as_bool();
+  // See DeepPotPTExpt::init for rationale. Defaults to false for
+  // pre-PR archives so they retain their previous behaviour.
+  has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
+                         metadata["has_message_passing"].as_bool();
   if (has_comm_artifact_) {
     try {
       with_comm_tempfile_ = std::make_unique<deepmd::ptexpt::TempFile>(
@@ -372,6 +376,46 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
           .clone()
           .to(device);

+  // Dispatch decision: see DeepPotPTExpt.cc for the full rationale.
+  // Single-rank without atom-map cannot drive the regular path (no safe
+  // ghost→local mapping); multi-rank without a with-comm artifact cannot
+  // drive border_op (no inter-rank exchange tensor). Both unsupported
+  // combinations fail-fast for every caller.
+  // ``nprocs > 1`` is the direct multi-rank predicate (LAMMPS pair
+  // styles set it by passing ``comm->nprocs`` to the ``InputNlist``
+  // constructor). Earlier drafts used ``nswap > 0`` as a proxy, but
+  // atom_style spin emits nswap > 0 even in single-rank, so the proxy
+  // is unsound.
+  bool multi_rank = (lmp_list.nprocs > 1);
+  bool atom_map_present = (lmp_list.mapping != nullptr);
+  bool use_with_comm = has_comm_artifact_ && multi_rank;
+  // Decision matrix (see PR #5450 description):
+  //   non-GNN model (has_message_passing_ == false): regular path is
+  //     always safe.
+  //   nghost == 0 (NoPbc, isolated cluster): always safe.
+  //   GNN model, multi-rank: requires has_comm_artifact_ (cell C-mr / D-mr)
+  //     else fail-fast (cell B-mr)
+  //   GNN model, single-rank: requires atom_map_present (cell A / C)
+  //     else fail-fast (cell B / D)
+  if (has_message_passing_ && nghost > 0) {
+    if (multi_rank && !has_comm_artifact_) {
+      throw deepmd::deepmd_exception(
+          "Multi-rank LAMMPS .pt2 inference requires the model to be "
+          "exported with `use_loc_mapping=False`, which compiles a "
+          "with-comm artifact for cross-rank ghost-feature exchange. "
+          "Re-export the model with use_loc_mapping=False and try again.");
+    }
+    if (!multi_rank && !atom_map_present) {
+      throw deepmd::deepmd_exception(
+          "Single-rank LAMMPS .pt2 inference requires `atom_modify map "
+          "yes` in the LAMMPS input (so InputNlist.mapping is populated "
+          "from the LAMMPS atom-map). The model gathers ghost-atom "
+          "features via this mapping; without it the C++ side has no "
+          "safe way to resolve ghost indices to local owners. C++ API "
+          "callers must set inlist.mapping explicitly before compute().");
+    }
+  }
+
   // LAMMPS sets ago=0 on every nlist rebuild, so ago>0 implies the cached
   // mapping and nlist tensors are still valid — see DeepPotPTExpt.cc for
   // the same rationale.
@@ -391,6 +435,11 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
               .clone()
               .to(device);
     } else {
+      // Identity fallback. See DeepPotPTExpt::compute_inner for the
+      // invariant rationale: this branch is only reached when the
+      // model is non-message-passing, nghost==0, or use_with_comm is
+      // true (border_op fills ghosts); other configurations were
+      // rejected by the fail-fast above.
       std::vector<std::int64_t> mapping(nall_real);
       for (int ii = 0; ii < nall_real; ii++) {
         mapping[ii] = ii;
@@ -452,8 +501,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
   // _with_comm), so C++ supplies the same 8 comm tensors as the
   // non-spin path. ``nlocal``/``nghost`` carry the real-atom counts
   // (pre atom-doubling); the spin override halves them internally.
+  //
+  // ``use_with_comm`` was computed earlier alongside the fail-fast
+  // dispatch check.
   std::vector<torch::Tensor> flat_outputs;
-  bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
   if (use_with_comm && !with_comm_loader) {
     throw deepmd::deepmd_exception(
         "Multi-rank LAMMPS requires the with-comm artifact, but it failed "

source/lib/include/neighbor_list.h

Lines changed: 13 additions & 2 deletions
@@ -46,6 +46,15 @@ struct InputNlist {
   int mask = 0xFFFFFFFF;
   /// mapping from all atoms to real atoms, in the size of nall
   int* mapping = nullptr;
+  /// number of MPI ranks (1 = single-rank). Settable only via the
+  /// trailing ``nprocs_`` argument of the comm-aware constructor (LAMMPS
+  /// pair styles pass ``comm->nprocs``). The lightweight constructors
+  /// leave it at 1 by construction — they carry no comm metadata
+  /// (``world``, ``sendlist``, ...), so they cannot drive the with-comm
+  /// dispatch path even if a non-1 value were forced here. Use this —
+  /// NOT ``nswap > 0`` — as the "is multi-rank?" predicate: ``atom_style
+  /// spin`` populates ``nswap`` even in single-rank.
+  int nprocs = 1;
   InputNlist()
       : inum(0),
         ilist(NULL),
@@ -83,7 +92,8 @@ struct InputNlist {
              int** sendlist,
              int* sendproc,
              int* recvproc,
-             void* world)
+             void* world,
+             int nprocs_ = 1)
       : inum(inum_),
         ilist(ilist_),
         numneigh(numneigh_),
@@ -95,7 +105,8 @@ struct InputNlist {
         sendlist(sendlist),
         sendproc(sendproc),
         recvproc(recvproc),
-        world(world) {};
+        world(world),
+        nprocs(nprocs_) {};
   ~InputNlist() {};
   /**
    * @brief Set mask for this neighbor list.
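The comment on `nprocs` explains why `nswap > 0` is an unsound multi-rank test. A reduced model makes the difference concrete. In the sketch below, `NlistInfo` is a hypothetical two-field view of `InputNlist`, and the `nswap` values are illustrative rather than taken from a LAMMPS run.

```python
from dataclasses import dataclass


@dataclass
class NlistInfo:
    """Hypothetical reduced view of InputNlist: only the two fields compared here."""

    nswap: int = 0   # number of communication swaps reported by the caller
    nprocs: int = 1  # number of MPI ranks; stays 1 unless set explicitly


def multi_rank_by_nswap(nl: NlistInfo) -> bool:
    """Old proxy: treat any communication swap as evidence of several ranks."""
    return nl.nswap > 0


def multi_rank_by_nprocs(nl: NlistInfo) -> bool:
    """Direct predicate: compare the rank count itself."""
    return nl.nprocs > 1


# Single rank with atom_style spin: swaps exist (assumed count) but nprocs is 1.
single_rank_spin = NlistInfo(nswap=6, nprocs=1)
# A genuine multi-rank run.
four_ranks = NlistInfo(nswap=6, nprocs=4)
```

For `single_rank_spin` the two predicates disagree: the proxy reports multi-rank while the rank count does not. They agree on `four_ranks` and on a default-constructed `NlistInfo()`.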

source/lmp/pair_deepmd.cpp

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ void PairDeepMD::compute(int eflag, int vflag) {
         list->inum, list->ilist, list->numneigh, list->firstneigh,
         commdata_->nswap, commdata_->sendnum, commdata_->recvnum,
         commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc,
-        commdata_->recvproc, &world);
+        commdata_->recvproc, &world, comm->nprocs);
     lmp_list.set_mask(NEIGHMASK);
     if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
       lmp_list.set_mapping(mapping_vec.data());
