Skip to content

Commit 7b84ec8

Browse files
committed
add atom_modify map yes
1 parent 4021ec7 commit 7b84ec8

7 files changed

Lines changed: 46 additions & 3 deletions

File tree

deepmd/pt/entrypoints/freeze_pt2.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ def _get_model_ntypes(model: torch.nn.Module) -> int:
7474
return int(descriptor.get_ntypes())
7575

7676

77+
def _model_has_message_passing(model: torch.nn.Module) -> bool:
78+
"""Return whether the regular .pt2 graph requires a real atom mapping."""
79+
for obj in (
80+
model,
81+
getattr(model, "atomic_model", None),
82+
model.get_descriptor() if hasattr(model, "get_descriptor") else None,
83+
):
84+
if obj is None or not hasattr(obj, "has_message_passing"):
85+
continue
86+
try:
87+
return bool(obj.has_message_passing())
88+
except (AttributeError, NotImplementedError):
89+
continue
90+
return False
91+
92+
7793
def _strip_shape_assertions(graph_module: torch.nn.Module) -> None:
7894
"""Remove deferred shape assertions from spin export graphs.
7995
@@ -243,6 +259,8 @@ def _collect_metadata(
243259
"dim_aparam": int(model.get_dim_aparam()),
244260
"dim_chg_spin": int(model.get_dim_chg_spin()),
245261
"mixed_types": bool(model.mixed_types()),
262+
"has_message_passing": _model_has_message_passing(model),
263+
"has_comm_artifact": False,
246264
"has_default_fparam": bool(model.has_default_fparam()),
247265
"default_fparam": _to_py_list(model.get_default_fparam()),
248266
"default_chg_spin": _to_py_list(model.get_default_chg_spin()),

deepmd/pt/model/model/sezm_model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2265,7 +2265,7 @@ def post_process_output_dens(
22652265
)
22662266

22672267
# =========================================================================
2268-
# Charge/Spin Condition Metadata
2268+
# Metadata
22692269
# =========================================================================
22702270

22712271
def has_chg_spin_ebd(self) -> bool:
@@ -2284,6 +2284,10 @@ def get_default_chg_spin(self) -> torch.Tensor | None:
22842284
"""Return default charge/spin conditions as a tensor."""
22852285
return self.atomic_model.get_default_chg_spin()
22862286

2287+
def has_message_passing(self) -> bool:
2288+
"""Return whether the descriptor performs message passing."""
2289+
return self.atomic_model.has_message_passing()
2290+
22872291
# =========================================================================
22882292
# Mode Management
22892293
# =========================================================================

examples/water/dpa4/lmp/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ Step PotEng KinEng TotEng Temp
5959
entries `"O"` and `"H"` respectively. When the element names are
6060
omitted, the mapping falls back to the `type_map` order stored in
6161
the `.pt2` metadata.
62+
- `atom_modify map yes` keeps the ghost / periodic-image to local-atom
63+
mapping explicit for `.pt2` graph inference. GNN-style `.pt2` models
64+
fail fast when this atom map is required but absent.
6265
- The 500-step `pretrained.pt` is intended as a smoke test, not a
6366
physically accurate water potential. Retrain with a longer schedule
6467
for production.

examples/water/dpa4/lmp/in.lammps

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
units metal
1010
boundary p p p
1111
atom_style atomic
12+
atom_modify map yes
1213

1314
neighbor 2.0 bin
1415
neigh_modify every 10 delay 0 check no

source/api_cc/tests/test_deeppot_ptexpt.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "DeepPot.h"
1313
#include "DeepPotPTExpt.h"
1414
#if defined(BUILD_PYTORCH)
15-
#include "commonPTExpt.h"
15+
#include "../src/commonPTExpt.h"
1616
#endif
1717
#include "neighbor_list.h"
1818
#include "test_utils.h"

source/lmp/pair_deepmd.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,8 @@ void PairDeepMD::compute(int eflag, int vflag) {
188188
}
189189
}
190190

191-
// mapping (for DPA-2 JAX)
191+
// mapping (for DPA-2/3 .pt2 GNN models that gather ghost features via
192+
// the LAMMPS atom-map; harmless for other models).
192193
std::vector<int> mapping_vec(nall, -1);
193194
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
194195
for (size_t ii = 0; ii < nall; ++ii) {

source/tests/pt/model/test_sezm_export.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ def test_archive_metadata(self) -> None:
367367
"dim_aparam",
368368
"dim_chg_spin",
369369
"mixed_types",
370+
"has_message_passing",
371+
"has_comm_artifact",
370372
"has_default_fparam",
371373
"default_chg_spin",
372374
"output_keys",
@@ -381,6 +383,8 @@ def test_archive_metadata(self) -> None:
381383
self.assertEqual(metadata["rcut"], self.params["descriptor"]["rcut"])
382384
self.assertEqual(list(metadata["sel"]), list(self.params["descriptor"]["sel"]))
383385
self.assertTrue(metadata["mixed_types"])
386+
self.assertTrue(metadata["has_message_passing"])
387+
self.assertFalse(metadata["has_comm_artifact"])
384388
self.assertFalse(metadata["is_spin"])
385389
self.assertEqual(metadata["dim_fparam"], 0)
386390
self.assertEqual(metadata["dim_aparam"], 0)
@@ -592,6 +596,16 @@ def test_metadata_records_ntypes_when_type_map_is_empty(self) -> None:
592596
self.assertEqual(metadata["type_map"], [])
593597
self.assertEqual(metadata["ntypes"], model.get_descriptor().get_ntypes())
594598

599+
def test_metadata_records_message_passing_contract(self) -> None:
600+
"""PTExpt fail-fast depends on the SeZM-specific metadata contract."""
601+
model = _build_tiny_sezm_model()
602+
603+
metadata = _collect_metadata(model, ["energy"])
604+
605+
self.assertTrue(model.has_message_passing())
606+
self.assertTrue(metadata["has_message_passing"])
607+
self.assertFalse(metadata["has_comm_artifact"])
608+
595609
def test_charge_spin_export_sample_has_runtime_input_slot(self) -> None:
596610
"""Charge/spin-conditioned exports should not bake defaults into the graph."""
597611
params = _tiny_sezm_model_params()
@@ -722,6 +736,8 @@ def fake_compile(_exported: torch.export.ExportedProgram, package_path: str):
722736
self.assertTrue(metadata["is_spin"])
723737
self.assertEqual(metadata["type_map"], params["type_map"])
724738
self.assertEqual(metadata["ntypes"], len(params["type_map"]))
739+
self.assertTrue(metadata["has_message_passing"])
740+
self.assertFalse(metadata["has_comm_artifact"])
725741
self.assertEqual(metadata["dim_chg_spin"], 0)
726742
self.assertIsNone(metadata["default_chg_spin"])
727743
self.assertEqual(metadata["use_spin"], params["spin"]["use_spin"])

0 commit comments

Comments
 (0)