Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord.requires_grad_(True)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

# Handle default chg_spin if descriptor supports it
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def forward_atomic(
the result dict, defined by the fitting net output def.
"""
nframes, nloc, nnei = nlist.shape
if self.do_grad_r() or self.do_grad_c():
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def forward_atomic(
) -> dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
if self.do_grad_r() or self.do_grad_c():
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord.requires_grad_(True)

# this will mask all -1 in the nlist
Expand Down
8 changes: 6 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,12 @@ def forward_common_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
force_coord = cc_ext
if self.atomic_model.do_grad_r() or self.atomic_model.do_grad_c():
if not force_coord.requires_grad:
force_coord = force_coord.clone().requires_grad_(True)
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext,
force_coord,
extended_atype,
nlist,
mapping=mapping,
Expand All @@ -319,7 +323,7 @@ def forward_common_lower(
model_predict = fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
force_coord,
do_atomic_virial=do_atomic_virial,
create_graph=self.training,
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
Expand Down
26 changes: 26 additions & 0 deletions source/tests/pt/model/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ def test_self_consistency(self) -> None:
to_numpy_array(ret1["energy"]),
)

def test_forward_common_atomic_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE)

coord = to_torch_tensor(self.coord_ext)
coord_view = coord.view(self.nf, self.nall, 3)
args = [
coord_view,
to_torch_tensor(self.atype_ext),
to_torch_tensor(self.nlist),
]
ret = md0.forward_common_atomic(*args)

self.assertIn("energy", ret)

def test_dp_consistency(self) -> None:
nf, nloc, nnei = self.nlist.shape
ds = DPDescrptSeA(
Expand Down
36 changes: 36 additions & 0 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,42 @@ def test_self_consistency(self) -> None:
atol=self.atol,
)

def test_forward_lower_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = EnergyFittingNet(
self.nt,
ds.get_dim_out(),
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)

coord_ext, atype_ext, _ = extend_coord_with_ghosts(
to_torch_tensor(self.coord),
to_torch_tensor(self.atype),
to_torch_tensor(self.cell),
self.rcut,
)
nlist = build_neighbor_list(
coord_ext,
atype_ext,
self.nloc,
self.rcut,
self.sel,
distinguish_types=(not md0.mixed_types()),
)
coord_view = coord_ext.view(self.nf, -1, 3)

ret = md0.forward_lower(coord_view, atype_ext, nlist, do_atomic_virial=True)

self.assertFalse(coord_view.requires_grad)
self.assertIn("extended_force", ret)
self.assertIn("virial", ret)

def test_dp_consistency(self) -> None:
nf, nloc = self.atype.shape
nfp, nap = 2, 3
Expand Down
Loading