Skip to content

Commit 66fc99e

Browse files
committed
Merge branch 'codex/update-test_ener-to-handle-masked-forces-tkzu2y' into codex/update-test_ener-to-handle-masked-forces
2 parents 13d43e2 + fe9fef8 commit 66fc99e

2 files changed

Lines changed: 41 additions & 33 deletions

File tree

deepmd/entrypoints/test.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -419,20 +419,20 @@ def test_ener(
419419
mae_e = mae(diff_e)
420420
rmse_e = rmse(diff_e)
421421
diff_f = force - test_data["force"][:numb_test]
422+
mae_f = mae(diff_f)
423+
rmse_f = rmse(diff_f)
424+
size_f = force.size
422425
if find_atom_pref == 1:
423-
atom_pref = test_data["atom_pref"][:numb_test]
424-
diff_f = diff_f * atom_pref
425-
size_f = np.sum(atom_pref)
426-
if size_f > 0:
427-
mae_f = np.sum(np.abs(diff_f)) / size_f
428-
rmse_f = np.sqrt(np.sum(diff_f * diff_f) / size_f)
426+
atom_weight = test_data["atom_pref"][:numb_test]
427+
weight_sum = np.sum(atom_weight)
428+
if weight_sum > 0:
429+
mae_fw = np.sum(np.abs(diff_f) * atom_weight) / weight_sum
430+
rmse_fw = np.sqrt(
431+
np.sum(diff_f * diff_f * atom_weight) / weight_sum
432+
)
429433
else:
430-
mae_f = 0.0
431-
rmse_f = 0.0
432-
else:
433-
mae_f = mae(diff_f)
434-
rmse_f = rmse(diff_f)
435-
size_f = force.size
434+
mae_fw = 0.0
435+
rmse_fw = 0.0
436436
diff_v = virial - test_data["virial"][:numb_test]
437437
mae_v = mae(diff_v)
438438
rmse_v = rmse(diff_v)
@@ -469,6 +469,11 @@ def test_ener(
469469
log.info(f"Force RMSE : {rmse_f:e} eV/A")
470470
dict_to_return["mae_f"] = (mae_f, size_f)
471471
dict_to_return["rmse_f"] = (rmse_f, size_f)
472+
if find_atom_pref == 1:
473+
log.info(f"Force weighted MAE : {mae_fw:e} eV/A")
474+
log.info(f"Force weighted RMSE: {rmse_fw:e} eV/A")
475+
dict_to_return["mae_fw"] = (mae_fw, weight_sum)
476+
dict_to_return["rmse_fw"] = (rmse_fw, weight_sum)
472477
if out_put_spin and find_force == 1:
473478
log.info(f"Force atom MAE : {mae_fr:e} eV/A")
474479
log.info(f"Force atom RMSE : {rmse_fr:e} eV/A")
@@ -614,6 +619,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
614619
if "rmse_f" in avg:
615620
log.info(f"Force MAE : {avg['mae_f']:e} eV/A")
616621
log.info(f"Force RMSE : {avg['rmse_f']:e} eV/A")
622+
if "rmse_fw" in avg:
623+
log.info(f"Force weighted MAE : {avg['mae_fw']:e} eV/A")
624+
log.info(f"Force weighted RMSE: {avg['rmse_fw']:e} eV/A")
617625
else:
618626
log.info(f"Force atom MAE : {avg['mae_fr']:e} eV/A")
619627
log.info(f"Force spin MAE : {avg['mae_fm']:e} eV/uB")

source/tests/pt/test_dp_test.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,15 @@
1414
import numpy as np
1515
import torch
1616

17-
from deepmd.entrypoints.test import test as dp_test
18-
from deepmd.entrypoints.test import test_ener as dp_test_ener
19-
from deepmd.infer.deep_eval import (
20-
DeepEval,
21-
)
17+
from deepmd.entrypoints.test import test as dp_test, test_ener as dp_test_ener
2218
from deepmd.pt.entrypoints.main import (
2319
get_trainer,
2420
)
2521
from deepmd.pt.utils.utils import (
2622
to_numpy_array,
2723
)
28-
from deepmd.utils.data import (
29-
DeepmdData,
30-
)
24+
from deepmd.infer.deep_eval import DeepEval
25+
from deepmd.utils.data import DeepmdData
3126

3227
from .model.test_permutation import (
3328
model_property,
@@ -147,25 +142,25 @@ def setUp(self) -> None:
147142
json.dump(self.config, fp, indent=4)
148143

149144

150-
class TestDPTestForceMask(DPTest, unittest.TestCase):
145+
class TestDPTestForceWeight(DPTest, unittest.TestCase):
151146
def setUp(self) -> None:
152-
self.detail_file = "test_dp_test_force_mask_detail"
147+
self.detail_file = "test_dp_test_force_weight_detail"
153148
input_json = str(Path(__file__).parent / "water/se_atten.json")
154149
with open(input_json) as f:
155150
self.config = json.load(f)
156151
self.config["training"]["numb_steps"] = 1
157152
self.config["training"]["save_freq"] = 1
158-
system_dir = self._prepare_masked_system()
153+
system_dir = self._prepare_weighted_system()
159154
data_file = [system_dir]
160155
self.config["training"]["training_data"]["systems"] = data_file
161156
self.config["training"]["validation_data"]["systems"] = data_file
162157
self.config["model"] = deepcopy(model_se_e2_a)
163158
self.system_dir = system_dir
164-
self.input_json = "test_dp_test_force_mask.json"
159+
self.input_json = "test_dp_test_force_weight.json"
165160
with open(self.input_json, "w") as fp:
166161
json.dump(self.config, fp, indent=4)
167162

168-
def _prepare_masked_system(self) -> str:
163+
def _prepare_weighted_system(self) -> str:
169164
src = Path(__file__).parent / "water/data/single"
170165
tmp_dir = tempfile.mkdtemp()
171166
shutil.copytree(src, tmp_dir, dirs_exist_ok=True)
@@ -179,7 +174,7 @@ def _prepare_masked_system(self) -> str:
179174
np.save(set_dir / "atom_pref.npy", atom_pref)
180175
return tmp_dir
181176

182-
def test_force_mask(self) -> None:
177+
def test_force_weight(self) -> None:
183178
trainer = get_trainer(deepcopy(self.config))
184179
with torch.device("cpu"):
185180
trainer.get_data(is_train=False)
@@ -219,13 +214,18 @@ def test_force_mask(self) -> None:
219214
)
220215
force_pred = ret[1].reshape([1, -1])
221216
force_true = test_data["force"][:1]
222-
mask = test_data["atom_pref"][:1]
223-
diff = (force_pred - force_true) * mask
224-
denom = mask.sum()
225-
mae_expected = np.sum(np.abs(diff)) / denom
226-
rmse_expected = np.sqrt(np.sum(diff * diff) / denom)
227-
np.testing.assert_allclose(err["mae_f"][0], mae_expected)
228-
np.testing.assert_allclose(err["rmse_f"][0], rmse_expected)
217+
weight = test_data["atom_pref"][:1]
218+
diff = force_pred - force_true
219+
diff_w = diff * weight
220+
denom = weight.sum()
221+
mae_expected = np.sum(np.abs(diff_w)) / denom
222+
rmse_expected = np.sqrt(np.sum(diff * diff * weight) / denom)
223+
mae_unweighted = np.sum(np.abs(diff)) / diff.size
224+
rmse_unweighted = np.sqrt(np.sum(diff * diff) / diff.size)
225+
np.testing.assert_allclose(err["mae_f"][0], mae_unweighted)
226+
np.testing.assert_allclose(err["rmse_f"][0], rmse_unweighted)
227+
np.testing.assert_allclose(err["mae_fw"][0], mae_expected)
228+
np.testing.assert_allclose(err["rmse_fw"][0], rmse_expected)
229229
os.unlink(tmp_model.name)
230230

231231
def tearDown(self) -> None:

0 commit comments

Comments
 (0)