diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 783ee9e766..d59d518cab 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -271,8 +271,8 @@ def forward_atomic( """ nframes, nloc, nnei = nlist.shape atype = extended_atype[:, :nloc] - if self.do_grad_r() or self.do_grad_c(): - extended_coord.requires_grad_(True) + if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: + extended_coord = extended_coord.clone().requires_grad_(True) # Handle default chg_spin if descriptor supports it if self.add_chg_spin_ebd and charge_spin is None: diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 5c0f616634..41fd49f096 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -258,7 +258,7 @@ def forward_atomic( the result dict, defined by the fitting net output def. """ nframes, nloc, nnei = nlist.shape - if self.do_grad_r() or self.do_grad_c(): + if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: extended_coord.requires_grad_(True) extended_coord = extended_coord.view(nframes, -1, 3) sorted_rcuts, sorted_sels = self._sort_rcuts_sels() diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 5750f7cfd1..9277725c23 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -275,7 +275,7 @@ def forward_atomic( ) -> dict[str, torch.Tensor]: nframes, nloc, nnei = nlist.shape extended_coord = extended_coord.view(nframes, -1, 3) - if self.do_grad_r() or self.do_grad_c(): + if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: extended_coord.requires_grad_(True) # this will mask all -1 in the nlist diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 713eab3d8c..78705b153c 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -306,8 +306,12 @@ def forward_common_lower( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam + force_coord = cc_ext + if self.atomic_model.do_grad_r() or self.atomic_model.do_grad_c(): + if not force_coord.requires_grad: + force_coord = force_coord.clone().requires_grad_(True) atomic_ret = self.atomic_model.forward_common_atomic( - cc_ext, + force_coord, extended_atype, nlist, mapping=mapping, @@ -319,7 +323,7 @@ def forward_common_lower( model_predict = fit_output_to_model_output( atomic_ret, self.atomic_output_def(), - cc_ext, + force_coord, do_atomic_virial=do_atomic_virial, create_graph=self.training, mask=atomic_ret["mask"] if "mask" in atomic_ret else None, diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 6d6a22f357..ddc8f6d5ed 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -74,6 +74,32 @@ def test_self_consistency(self) -> None: to_numpy_array(ret1["energy"]), ) + def test_forward_common_atomic_accepts_leaf_view_input(self) -> None: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + + coord = to_torch_tensor(self.coord_ext) + coord_view = coord.view(self.nf, self.nall, 3) + args = [ + coord_view, + to_torch_tensor(self.atype_ext), + to_torch_tensor(self.nlist), + ] + ret = md0.forward_common_atomic(*args) + + self.assertIn("energy", ret) + def test_dp_consistency(self) -> None: nf, nloc, nnei = self.nlist.shape ds = DPDescrptSeA( diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index f4e350869a..37f0fb9caa 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -114,6 +114,42 @@ def test_self_consistency(self) -> None: atol=self.atol, ) + def test_forward_lower_accepts_leaf_view_input(self) -> None: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + + coord_ext, atype_ext, _ = extend_coord_with_ghosts( + to_torch_tensor(self.coord), + to_torch_tensor(self.atype), + to_torch_tensor(self.cell), + self.rcut, + ) + nlist = build_neighbor_list( + coord_ext, + atype_ext, + self.nloc, + self.rcut, + self.sel, + distinguish_types=(not md0.mixed_types()), + ) + coord_view = coord_ext.view(self.nf, -1, 3) + + ret = md0.forward_lower(coord_view, atype_ext, nlist, do_atomic_virial=True) + + self.assertFalse(coord_view.requires_grad) + self.assertIn("extended_force", ret) + self.assertIn("virial", ret) + def test_dp_consistency(self) -> None: nf, nloc = self.atype.shape nfp, nap = 2, 3