Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord = extended_coord.clone().requires_grad_(True)

# Handle default chg_spin if descriptor supports it
if self.add_chg_spin_ebd and charge_spin is None:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def forward_atomic(
the result dict, defined by the fitting net output def.
"""
nframes, nloc, nnei = nlist.shape
if self.do_grad_r() or self.do_grad_c():
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def forward_atomic(
) -> dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
if self.do_grad_r() or self.do_grad_c():
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord.requires_grad_(True)

# this will mask all -1 in the nlist
Expand Down
8 changes: 6 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,12 @@ def forward_common_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
force_coord = cc_ext
if self.atomic_model.do_grad_r() or self.atomic_model.do_grad_c():
if not force_coord.requires_grad:
force_coord = force_coord.clone().requires_grad_(True)
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext,
force_coord,
extended_atype,
nlist,
mapping=mapping,
Expand All @@ -319,7 +323,7 @@ def forward_common_lower(
model_predict = fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
force_coord,
do_atomic_virial=do_atomic_virial,
create_graph=self.training,
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
Expand Down
26 changes: 26 additions & 0 deletions source/tests/pt/model/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ def test_self_consistency(self) -> None:
to_numpy_array(ret1["energy"]),
)

def test_forward_common_atomic_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE)

coord = to_torch_tensor(self.coord_ext)
coord_view = coord.view(self.nf, self.nall, 3)
args = [
coord_view,
to_torch_tensor(self.atype_ext),
to_torch_tensor(self.nlist),
]
ret = md0.forward_common_atomic(*args)

self.assertIn("energy", ret)

def test_dp_consistency(self) -> None:
nf, nloc, nnei = self.nlist.shape
ds = DPDescrptSeA(
Expand Down
36 changes: 36 additions & 0 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,42 @@ def test_self_consistency(self) -> None:
atol=self.atol,
)

def test_forward_lower_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = EnergyFittingNet(
self.nt,
ds.get_dim_out(),
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)

coord_ext, atype_ext, _ = extend_coord_with_ghosts(
to_torch_tensor(self.coord),
to_torch_tensor(self.atype),
to_torch_tensor(self.cell),
self.rcut,
)
nlist = build_neighbor_list(
coord_ext,
atype_ext,
self.nloc,
self.rcut,
self.sel,
distinguish_types=(not md0.mixed_types()),
)
coord_view = coord_ext.view(self.nf, -1, 3)

ret = md0.forward_lower(coord_view, atype_ext, nlist, do_atomic_virial=True)

self.assertFalse(coord_view.requires_grad)
self.assertIn("extended_force", ret)
self.assertIn("virial", ret)

def test_dp_consistency(self) -> None:
nf, nloc = self.atype.shape
nfp, nap = 2, 3
Expand Down
Loading