Skip to content
35 changes: 35 additions & 0 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,41 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
# (per-layer ghost-feature MPI exchange via deepmd_export::border_op).
# The C++ DeepPotPTExpt / DeepSpinPTExpt loaders branch on this flag.
meta["has_comm_artifact"] = _needs_with_comm_artifact(model)

# Whether the model's regular .pt2 graph consumes the ``mapping``
# tensor to gather per-layer ghost-atom features from local atoms.
# Mirrors the descriptor's ``has_message_passing()`` API: True for
# any message-passing descriptor (DPA2, DPA3, hybrids over those);
# False for non-message-passing descriptors (se_e2_a, DPA1, etc.).
# The C++ side gates its fail-fast on this — an absent mapping is
# fatal only for models that would silently corrupt ghost features
# otherwise.
#
# Lookup order: model -> atomic_model -> descriptor. Going through
# ``atomic_model.has_message_passing()`` is important for composite
# atomic models (e.g. ``LinearAtomicModel`` in DP-ZBL) which don't
# expose a single ``.descriptor`` but do aggregate the flag across
# their sub-models. ``descriptor.has_message_passing()`` is the
# fallback for any future wrapper that lacks the higher-level
# methods.
def _probe_has_message_passing(obj: object) -> bool | None:
if obj is None or not hasattr(obj, "has_message_passing"):
return None
try:
return bool(obj.has_message_passing())
except (AttributeError, NotImplementedError):
return None

result: bool | None = None
for obj in (
model,
getattr(model, "atomic_model", None),
getattr(getattr(model, "atomic_model", None), "descriptor", None),
):
result = _probe_has_message_passing(obj)
if result is not None:
break
meta["has_message_passing"] = result if result is not None else False
return meta


Expand Down
6 changes: 5 additions & 1 deletion source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ extern DP_Nlist* DP_NewNlist(int inum_,
* each swap.
* @param[in] world Pointer to the MPI communicator or similar communication
* world used for the operation.
* @param[in] nprocs Number of MPI ranks (1 = single-rank). Used by
* ``DeepPotPTExpt`` / ``DeepSpinPTExpt`` to choose between the regular
* and with-comm artifacts. Defaults to 1 if not supplied.
* @returns A pointer to the initialized neighbor list with communication
* capabilities.
*/
Expand All @@ -66,7 +69,8 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int** sendlist,
int* sendproc,
int* recvproc,
void* world);
void* world,
int nprocs);

/**
* @brief Set mask for a neighbor list.
Expand Down
6 changes: 4 additions & 2 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,8 @@ struct InputNlist {
int** sendlist,
int* sendproc,
int* recvproc,
void* world)
void* world,
int nprocs = 1)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
Expand All @@ -847,7 +848,8 @@ struct InputNlist {
sendlist,
sendproc,
recvproc,
world)) {};
world,
nprocs)) {};
~InputNlist() { DP_DeleteNlist(nl); };
/// @brief C API neighbor list.
DP_Nlist* nl;
Expand Down
7 changes: 4 additions & 3 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
int** sendlist,
int* sendproc,
int* recvproc,
void* world) {
void* world,
int nprocs) {
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, nswap, sendnum,
recvnum, firstrecv, sendlist, sendproc, recvproc,
world);
recvnum, firstrecv, sendlist, sendproc, recvproc, world,
nprocs);
DP_Nlist* new_nl = new DP_Nlist(nl);
return new_nl;
}
Expand Down
9 changes: 9 additions & 0 deletions source/api_cc/include/DeepPotPTExpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,15 @@ class DeepPotPTExpt : public DeepPotBackend {
// passing. ``with_comm_tempfile_`` owns the extracted nested .pt2
// for the lifetime of ``with_comm_loader``.
bool has_comm_artifact_ = false;
// Whether the regular .pt2 graph consumes the mapping tensor for
// ghost-feature gather (true for any message-passing descriptor:
// DPA2/DPA3/hybrids; false for se_e2_a/DPA1/etc.). Mirrors the
// descriptor's ``has_message_passing()`` API; read from the
// ``has_message_passing`` metadata field. Defaults to false for
// pre-PR .pt2 archives that lack the field so non-GNN archives
// continue to work; GNN archives must be regenerated to opt into
// the fail-fast guard against the silent-corruption bug.
bool has_message_passing_ = false;
std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

Expand Down
3 changes: 3 additions & 0 deletions source/api_cc/include/DeepSpinPTExpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ class DeepSpinPTExpt : public DeepSpinBackend {
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> loader;
// Optional with-comm artifact for multi-rank GNN spin inference.
bool has_comm_artifact_ = false;
// Mirrors descriptor's has_message_passing(). See DeepPotPTExpt.h
// for the full rationale and gating role.
bool has_message_passing_ = false;
std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

Expand Down
78 changes: 70 additions & 8 deletions source/api_cc/src/DeepPotPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ void DeepPotPTExpt::init(const std::string& model,
// exchange and producing wrong results.
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
metadata["has_comm_artifact"].as_bool();
// Whether the regular .pt2 graph consumes ``mapping`` for ghost-atom
// feature gather. Mirrors the descriptor's ``has_message_passing()``
// API: true for message-passing descriptors (DPA2, DPA3, hybrids
// over those), false for non-message-passing descriptors (se_e2_a,
// DPA1, etc.). Pre-PR .pt2 archives lack this field; default to
// false so they retain their previous behaviour (non-GNN archives
// continue to work; GNN archives that had the original
// silent-corruption bug must be regenerated to opt into the fail-
// fast guard). All in-tree fixtures are regenerated by the gen
// scripts and carry the explicit value.
has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
metadata["has_message_passing"].as_bool();
if (has_comm_artifact_) {
try {
// Extract the nested ``extra/forward_lower_with_comm.pt2`` into a
Expand Down Expand Up @@ -353,6 +365,51 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);

// Dispatch decision: use the with-comm artifact when LAMMPS is running
// multi-rank. ``lmp_list.nprocs > 1`` is the direct predicate;
// LAMMPS pair styles populate it by passing ``comm->nprocs`` to the
// ``InputNlist`` constructor. Earlier drafts used ``nswap > 0`` as a
// proxy, but that breaks for ``atom_style spin`` (which emits
// nswap > 0 even in single-rank to propagate PBC ghost spins).
// ``nprocs`` is unambiguous.
//
// The regular artifact uses ``mapping`` to gather ghost-atom features
// from local-atom embeddings (``index_select(node_ebd[1, nloc, dim],
// mapping)``). Identity-mapping for ghost slots is silently wrong,
// so fail-fast when the regular path would be taken without a real
// mapping — applies uniformly to every caller (LAMMPS pair, ctest
// fixtures, direct C++ API users). Callers that want the regular
// path must populate ``lmp_list.mapping``.
bool multi_rank = (lmp_list.nprocs > 1);
bool atom_map_present = (lmp_list.mapping != nullptr);
bool use_with_comm = has_comm_artifact_ && multi_rank;
// Decision matrix (see PR #5450 description):
// non-GNN model (has_message_passing_ == false): regular path is
// always safe.
// nghost == 0 (NoPbc, isolated cluster): always safe.
// GNN model, multi-rank: requires has_comm_artifact_ (cell C-mr / D-mr)
// else fail-fast (cell B-mr)
// GNN model, single-rank: requires atom_map_present (cell A / C)
// else fail-fast (cell B / D)
if (has_message_passing_ && nghost > 0) {
if (multi_rank && !has_comm_artifact_) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS .pt2 inference requires the model to be "
"exported with `use_loc_mapping=False`, which compiles a "
"with-comm artifact for cross-rank ghost-feature exchange. "
"Re-export the model with use_loc_mapping=False and try again.");
}
if (!multi_rank && !atom_map_present) {
throw deepmd::deepmd_exception(
"Single-rank LAMMPS .pt2 inference requires `atom_modify map "
"yes` in the LAMMPS input (so InputNlist.mapping is populated "
"from the LAMMPS atom-map). The model gathers ghost-atom "
"features via this mapping; without it the C++ side has no "
"safe way to resolve ghost indices to local owners. C++ API "
"callers must set inlist.mapping explicitly before compute().");
}
}

// LAMMPS sets ago=0 on every nlist rebuild (neighbor rebuild, re-partition,
// atom exchange between subdomains), so `ago > 0` implies the cached
// mapping and nlist tensors are still valid. Rebuild only on ago==0.
Expand All @@ -372,7 +429,15 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);
} else {
// Default identity mapping for local atoms
// Identity fallback. The fail-fast above guarantees we only
// reach this branch when one of these is true:
// - The model is non-message-passing (mapping is unused).
// - ``nghost == 0`` (no ghosts to gather, identity is trivially
// correct).
// - ``use_with_comm`` is true (the with-comm graph fills ghost
// features via border_op and ignores this tensor for ghost
// gather — see deepmd/pt_expt/descriptor/
// repflows.py::_exchange_ghosts).
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
Expand Down Expand Up @@ -428,14 +493,11 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
aparam_tensor = torch::zeros({0}, options).to(device);
}

// Phase 4 dispatch: use the with-comm artifact when LAMMPS is
// running multi-rank. ``lmp_list.nswap > 0`` is the proxy for
// "multi-rank with cross-domain communication"; in single-rank
// mode LAMMPS sets nswap=0. Falling back to the regular artifact
// for nswap=0 is correct because that artifact uses the mapping
// tensor to gather ghost embeddings from local atoms.
// ``use_with_comm`` was computed earlier alongside the fail-fast
// dispatch check. Use the with-comm artifact for the multi-rank case
// (the regular artifact uses the mapping tensor to gather ghost
// embeddings, which only works in single-rank).
std::vector<torch::Tensor> flat_outputs;
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
if (use_with_comm && !with_comm_loader) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
Expand Down
53 changes: 52 additions & 1 deletion source/api_cc/src/DeepSpinPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ void DeepSpinPTExpt::init(const std::string& model,
// dropping the MPI exchange.
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
metadata["has_comm_artifact"].as_bool();
// See DeepPotPTExpt::init for rationale. Defaults to false for
// pre-PR archives so they retain their previous behaviour.
has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
metadata["has_message_passing"].as_bool();
if (has_comm_artifact_) {
try {
with_comm_tempfile_ = std::make_unique<deepmd::ptexpt::TempFile>(
Expand Down Expand Up @@ -372,6 +376,46 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);

// Dispatch decision: see DeepPotPTExpt.cc for the full rationale.
// Single-rank without atom-map cannot drive the regular path (no safe
// ghost→local mapping); multi-rank without a with-comm artifact cannot
// drive border_op (no inter-rank exchange tensor). Both unsupported
// combinations fail-fast for every caller.
// ``nprocs > 1`` is the direct multi-rank predicate (LAMMPS pair
// styles set it by passing ``comm->nprocs`` to the ``InputNlist``
// constructor). Earlier drafts used ``nswap > 0`` as a proxy, but
// atom_style spin emits nswap > 0 even in single-rank, so the proxy
// is unsound.
bool multi_rank = (lmp_list.nprocs > 1);
bool atom_map_present = (lmp_list.mapping != nullptr);
bool use_with_comm = has_comm_artifact_ && multi_rank;
// Decision matrix (see PR #5450 description):
// non-GNN model (has_message_passing_ == false): regular path is
// always safe.
// nghost == 0 (NoPbc, isolated cluster): always safe.
// GNN model, multi-rank: requires has_comm_artifact_ (cell C-mr / D-mr)
// else fail-fast (cell B-mr)
// GNN model, single-rank: requires atom_map_present (cell A / C)
// else fail-fast (cell B / D)
if (has_message_passing_ && nghost > 0) {
if (multi_rank && !has_comm_artifact_) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS .pt2 inference requires the model to be "
"exported with `use_loc_mapping=False`, which compiles a "
"with-comm artifact for cross-rank ghost-feature exchange. "
"Re-export the model with use_loc_mapping=False and try again.");
}
if (!multi_rank && !atom_map_present) {
throw deepmd::deepmd_exception(
"Single-rank LAMMPS .pt2 inference requires `atom_modify map "
"yes` in the LAMMPS input (so InputNlist.mapping is populated "
"from the LAMMPS atom-map). The model gathers ghost-atom "
"features via this mapping; without it the C++ side has no "
"safe way to resolve ghost indices to local owners. C++ API "
"callers must set inlist.mapping explicitly before compute().");
}
}

// LAMMPS sets ago=0 on every nlist rebuild, so ago>0 implies the cached
// mapping and nlist tensors are still valid — see DeepPotPTExpt.cc for
// the same rationale.
Expand All @@ -391,6 +435,11 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);
} else {
// Identity fallback. See DeepPotPTExpt::compute_inner for the
// invariant rationale: this branch is only reached when the
// model is non-message-passing, nghost==0, or use_with_comm is
// true (border_op fills ghosts); other configurations were
// rejected by the fail-fast above.
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
Expand Down Expand Up @@ -452,8 +501,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
// _with_comm), so C++ supplies the same 8 comm tensors as the
// non-spin path. ``nlocal``/``nghost`` carry the real-atom counts
// (pre atom-doubling); the spin override halves them internally.
//
// ``use_with_comm`` was computed earlier alongside the fail-fast
// dispatch check.
std::vector<torch::Tensor> flat_outputs;
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
if (use_with_comm && !with_comm_loader) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
Expand Down
15 changes: 13 additions & 2 deletions source/lib/include/neighbor_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ struct InputNlist {
int mask = 0xFFFFFFFF;
/// mapping from all atoms to real atoms, in the size of nall
int* mapping = nullptr;
/// number of MPI ranks (1 = single-rank). Settable only via the
/// trailing ``nprocs_`` argument of the comm-aware constructor (LAMMPS
/// pair styles pass ``comm->nprocs``). The lightweight constructors
/// leave it at 1 by construction — they carry no comm metadata
/// (``world``, ``sendlist``, ...), so they cannot drive the with-comm
/// dispatch path even if a non-1 value were forced here. Use this —
/// NOT ``nswap > 0`` — as the "is multi-rank?" predicate: ``atom_style
/// spin`` populates ``nswap`` even in single-rank.
int nprocs = 1;
InputNlist()
: inum(0),
ilist(NULL),
Expand Down Expand Up @@ -83,7 +92,8 @@ struct InputNlist {
int** sendlist,
int* sendproc,
int* recvproc,
void* world)
void* world,
int nprocs_ = 1)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
Expand All @@ -95,7 +105,8 @@ struct InputNlist {
sendlist(sendlist),
sendproc(sendproc),
recvproc(recvproc),
world(world) {};
world(world),
nprocs(nprocs_) {};
~InputNlist() {};
/**
* @brief Set mask for this neighbor list.
Expand Down
2 changes: 1 addition & 1 deletion source/lmp/pair_deepmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ void PairDeepMD::compute(int eflag, int vflag) {
list->inum, list->ilist, list->numneigh, list->firstneigh,
commdata_->nswap, commdata_->sendnum, commdata_->recvnum,
commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc,
commdata_->recvproc, &world);
commdata_->recvproc, &world, comm->nprocs);
lmp_list.set_mask(NEIGHMASK);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
lmp_list.set_mapping(mapping_vec.data());
Expand Down
14 changes: 13 additions & 1 deletion source/lmp/pair_deepspin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,15 @@ void PairDeepSpin::compute(int eflag, int vflag) {
}
}

// mapping (for DPA-2/3 .pt2 GNN models that gather ghost features via
// the LAMMPS atom-map; harmless for other models).
std::vector<int> mapping_vec(nall, -1);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
for (size_t ii = 0; ii < nall; ++ii) {
mapping_vec[ii] = atom->map(atom->tag[ii]);
}
}

if (do_compute_aparam) {
make_aparam_from_compute(daparam);
} else if (aparam.size() > 0) {
Expand Down Expand Up @@ -242,8 +251,11 @@ void PairDeepSpin::compute(int eflag, int vflag) {
list->inum, list->ilist, list->numneigh, list->firstneigh,
commdata_->nswap, commdata_->sendnum, commdata_->recvnum,
commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc,
commdata_->recvproc, &world);
commdata_->recvproc, &world, comm->nprocs);
lmp_list.set_mask(NEIGHMASK);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
lmp_list.set_mapping(mapping_vec.data());
}
if (single_model || multi_models_no_mod_devi) {
// cvflag_atom is the right flag for the cvatom matrix
if (!(eflag_atom || cvflag_atom)) {
Expand Down
Loading
Loading