Skip to content

Commit fe9fef8

Browse files
committed
handle atom_pref as weights in test entry point
1 parent accc331 commit fe9fef8

2 files changed

Lines changed: 118 additions & 3 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def test_ener(
291291

292292
data.add("energy", 1, atomic=False, must=False, high_prec=True)
293293
data.add("force", 3, atomic=True, must=False, high_prec=False)
294+
data.add("atom_pref", 1, atomic=True, must=False, high_prec=False, repeat=3)
294295
data.add("virial", 9, atomic=False, must=False, high_prec=False)
295296
if dp.has_efield:
296297
data.add("efield", 3, atomic=True, must=True, high_prec=False)
@@ -313,6 +314,7 @@ def test_ener(
313314
find_force = test_data.get("find_force")
314315
find_virial = test_data.get("find_virial")
315316
find_force_mag = test_data.get("find_force_mag")
317+
find_atom_pref = test_data.get("find_atom_pref")
316318
mixed_type = data.mixed_type
317319
natoms = len(test_data["type"][0])
318320
nframes = test_data["box"].shape[0]
@@ -419,6 +421,18 @@ def test_ener(
419421
diff_f = force - test_data["force"][:numb_test]
420422
mae_f = mae(diff_f)
421423
rmse_f = rmse(diff_f)
424+
size_f = force.size
425+
if find_atom_pref == 1:
426+
atom_weight = test_data["atom_pref"][:numb_test]
427+
weight_sum = np.sum(atom_weight)
428+
if weight_sum > 0:
429+
mae_fw = np.sum(np.abs(diff_f) * atom_weight) / weight_sum
430+
rmse_fw = np.sqrt(
431+
np.sum(diff_f * diff_f * atom_weight) / weight_sum
432+
)
433+
else:
434+
mae_fw = 0.0
435+
rmse_fw = 0.0
422436
diff_v = virial - test_data["virial"][:numb_test]
423437
mae_v = mae(diff_v)
424438
rmse_v = rmse(diff_v)
@@ -453,8 +467,13 @@ def test_ener(
453467
if not out_put_spin and find_force == 1:
454468
log.info(f"Force MAE : {mae_f:e} eV/A")
455469
log.info(f"Force RMSE : {rmse_f:e} eV/A")
456-
dict_to_return["mae_f"] = (mae_f, force.size)
457-
dict_to_return["rmse_f"] = (rmse_f, force.size)
470+
dict_to_return["mae_f"] = (mae_f, size_f)
471+
dict_to_return["rmse_f"] = (rmse_f, size_f)
472+
if find_atom_pref == 1:
473+
log.info(f"Force weighted MAE : {mae_fw:e} eV/A")
474+
log.info(f"Force weighted RMSE: {rmse_fw:e} eV/A")
475+
dict_to_return["mae_fw"] = (mae_fw, weight_sum)
476+
dict_to_return["rmse_fw"] = (rmse_fw, weight_sum)
458477
if out_put_spin and find_force == 1:
459478
log.info(f"Force atom MAE : {mae_fr:e} eV/A")
460479
log.info(f"Force atom RMSE : {rmse_fr:e} eV/A")
@@ -600,6 +619,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
600619
if "rmse_f" in avg:
601620
log.info(f"Force MAE : {avg['mae_f']:e} eV/A")
602621
log.info(f"Force RMSE : {avg['rmse_f']:e} eV/A")
622+
if "rmse_fw" in avg:
623+
log.info(f"Force weighted MAE : {avg['mae_fw']:e} eV/A")
624+
log.info(f"Force weighted RMSE: {avg['rmse_fw']:e} eV/A")
603625
else:
604626
log.info(f"Force atom MAE : {avg['mae_fr']:e} eV/A")
605627
log.info(f"Force spin MAE : {avg['mae_fm']:e} eV/uB")

source/tests/pt/test_dp_test.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
import numpy as np
1515
import torch
1616

17-
from deepmd.entrypoints.test import test as dp_test
17+
from deepmd.entrypoints.test import test as dp_test, test_ener as dp_test_ener
1818
from deepmd.pt.entrypoints.main import (
1919
get_trainer,
2020
)
2121
from deepmd.pt.utils.utils import (
2222
to_numpy_array,
2323
)
24+
from deepmd.infer.deep_eval import DeepEval
25+
from deepmd.utils.data import DeepmdData
2426

2527
from .model.test_permutation import (
2628
model_property,
@@ -140,6 +142,97 @@ def setUp(self) -> None:
140142
json.dump(self.config, fp, indent=4)
141143

142144

145+
class TestDPTestForceWeight(DPTest, unittest.TestCase):
146+
def setUp(self) -> None:
147+
self.detail_file = "test_dp_test_force_weight_detail"
148+
input_json = str(Path(__file__).parent / "water/se_atten.json")
149+
with open(input_json) as f:
150+
self.config = json.load(f)
151+
self.config["training"]["numb_steps"] = 1
152+
self.config["training"]["save_freq"] = 1
153+
system_dir = self._prepare_weighted_system()
154+
data_file = [system_dir]
155+
self.config["training"]["training_data"]["systems"] = data_file
156+
self.config["training"]["validation_data"]["systems"] = data_file
157+
self.config["model"] = deepcopy(model_se_e2_a)
158+
self.system_dir = system_dir
159+
self.input_json = "test_dp_test_force_weight.json"
160+
with open(self.input_json, "w") as fp:
161+
json.dump(self.config, fp, indent=4)
162+
163+
def _prepare_weighted_system(self) -> str:
164+
src = Path(__file__).parent / "water/data/single"
165+
tmp_dir = tempfile.mkdtemp()
166+
shutil.copytree(src, tmp_dir, dirs_exist_ok=True)
167+
set_dir = Path(tmp_dir) / "set.000"
168+
forces = np.load(set_dir / "force.npy")
169+
forces[0, -3:] += 10.0
170+
np.save(set_dir / "force.npy", forces)
171+
natoms = forces.shape[1] // 3
172+
atom_pref = np.ones((forces.shape[0], natoms), dtype=forces.dtype)
173+
atom_pref[:, -1] = 0.0
174+
np.save(set_dir / "atom_pref.npy", atom_pref)
175+
return tmp_dir
176+
177+
def test_force_weight(self) -> None:
178+
trainer = get_trainer(deepcopy(self.config))
179+
with torch.device("cpu"):
180+
trainer.get_data(is_train=False)
181+
model = torch.jit.script(trainer.model)
182+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
183+
torch.jit.save(model, tmp_model.name)
184+
dp = DeepEval(tmp_model.name)
185+
data = DeepmdData(
186+
self.system_dir,
187+
set_prefix="set",
188+
shuffle_test=False,
189+
type_map=dp.get_type_map(),
190+
sort_atoms=False,
191+
)
192+
err = dp_test_ener(
193+
dp,
194+
data,
195+
self.system_dir,
196+
numb_test=1,
197+
detail_file=None,
198+
has_atom_ener=False,
199+
)
200+
test_data = data.get_test()
201+
coord = test_data["coord"].reshape([1, -1])
202+
box = test_data["box"][:1]
203+
atype = test_data["type"][0]
204+
ret = dp.eval(
205+
coord,
206+
box,
207+
atype,
208+
fparam=None,
209+
aparam=None,
210+
atomic=False,
211+
efield=None,
212+
mixed_type=False,
213+
spin=None,
214+
)
215+
force_pred = ret[1].reshape([1, -1])
216+
force_true = test_data["force"][:1]
217+
weight = test_data["atom_pref"][:1]
218+
diff = force_pred - force_true
219+
diff_w = diff * weight
220+
denom = weight.sum()
221+
mae_expected = np.sum(np.abs(diff_w)) / denom
222+
rmse_expected = np.sqrt(np.sum(diff * diff * weight) / denom)
223+
mae_unweighted = np.sum(np.abs(diff)) / diff.size
224+
rmse_unweighted = np.sqrt(np.sum(diff * diff) / diff.size)
225+
np.testing.assert_allclose(err["mae_f"][0], mae_unweighted)
226+
np.testing.assert_allclose(err["rmse_f"][0], rmse_unweighted)
227+
np.testing.assert_allclose(err["mae_fw"][0], mae_expected)
228+
np.testing.assert_allclose(err["rmse_fw"][0], rmse_expected)
229+
os.unlink(tmp_model.name)
230+
231+
def tearDown(self) -> None:
232+
super().tearDown()
233+
shutil.rmtree(self.system_dir)
234+
235+
143236
class TestDPTestPropertySeA(unittest.TestCase):
144237
def setUp(self) -> None:
145238
self.detail_file = "test_dp_test_property_detail"

0 commit comments

Comments
 (0)