Skip to content

Commit 0f25a34

Browse files
committed
minor code improvement based on @coderabbitai
1 parent 9ff8872 commit 0f25a34

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

deepmd/pt/modifier/dipole_charge.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,14 @@ def forward(
129129
Atom parameters with shape (nframes, natoms, nap), by default None
130130
do_atomic_virial : bool, optional
131131
Whether to compute atomic virial, by default False
132+
Note: This parameter is currently not implemented and is ignored
132133
133134
Returns
134135
-------
135136
dict[str, torch.Tensor]
136137
Dictionary containing the correction terms:
137138
- energy: Energy correction tensor with shape (nframes, 1)
138-
- force: Force correction tensor with shape (nframes, natoms+nsel, 3)
139+
- force: Force correction tensor with shape (nframes, natoms, 3)
139140
- virial: Virial correction tensor with shape (nframes, 3, 3)
140141
"""
141142
if box is None:
@@ -305,7 +306,10 @@ def extend_system_coord(
305306

306307
# nframe x natoms x 3
307308
dipole = torch.cat(all_dipole, dim=0)
308-
assert dipole.shape[0] == nframes
309+
if dipole.shape[0] != nframes:
310+
raise RuntimeError(
311+
f"Dipole shape mismatch: expected {nframes} frames, got {dipole.shape[0]}"
312+
)
309313

310314
dipole_reshaped = dipole.reshape(nframes, natoms, 3)
311315
coord_reshaped = coord.reshape(nframes, natoms, 3)

source/tests/pt/modifier/test_dipole_charge.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import json
33
import os
4+
import tempfile
45
import unittest
56
from pathlib import (
67
Path,
@@ -50,6 +51,9 @@ def ref_data():
5051

5152
class TestDipoleChargeModifier(unittest.TestCase):
5253
def setUp(self) -> None:
54+
self.test_dir = tempfile.TemporaryDirectory()
55+
self.orig_dir = os.getcwd()
56+
os.chdir(self.test_dir.name)
5357
# setup parameter
5458
# numerical consistency can only be achieved with high prec
5559
self.ewald_h = 0.1
@@ -170,7 +174,7 @@ def test_serialize(self):
170174
to_numpy_array(ret0["virial"]), to_numpy_array(ret1["virial"])
171175
)
172176

173-
def test_box_none_warning(self):
177+
def test_box_none_error(self):
174178
"""Test that a RuntimeError is raised when box is None."""
175179
coord, _b, atype = ref_data()
176180
# consistent with the input shape from BaseModifier.modify_data
@@ -209,12 +213,5 @@ def test_train(self):
209213
trainer.run()
210214

211215
def tearDown(self) -> None:
212-
for f in os.listdir("."):
213-
if f.startswith("frozen_model") and f.endswith(".pth"):
214-
os.remove(f)
215-
if f.startswith("dw_model") and (f.endswith(".pth") or f.endswith(".pb")):
216-
os.remove(f)
217-
if f.startswith("model.ckpt") and f.endswith(".pt"):
218-
os.remove(f)
219-
if f in ["lcurve.out", "checkpoint"]:
220-
os.remove(f)
216+
os.chdir(self.orig_dir)
217+
self.test_dir.cleanup()

0 commit comments

Comments
 (0)