Skip to content

Commit 604fff9

Browse files
committed
fix comment
1 parent 4c61898 commit 604fff9

3 files changed

Lines changed: 184 additions & 4 deletions

File tree

deepmd/utils/argcheck.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3220,7 +3220,9 @@ def loss_ener() -> list[Argument]:
32203220
doc_use_default_pf = (
32213221
"If true, use default atom_pref of 1.0 for all atoms when atom_pref data is not provided. "
32223222
"This allows using the prefactor force loss (pf) without requiring atom_pref.npy files in training data. "
3223-
"When atom_pref.npy is provided, it will be used as-is regardless of this setting."
3223+
"When atom_pref.npy is provided, it will be used as-is regardless of this setting. "
3224+
"Note: this option is only effective for the PyTorch/DPModel backends; "
3225+
"the TensorFlow and Paddle backends raise NotImplementedError when set to true."
32243226
)
32253227
doc_start_pref_gf = start_pref("generalized force", label="drdq", abbr="gf")
32263228
doc_limit_pref_gf = limit_pref("generalized force")

source/tests/consistent/loss/test_ener.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,3 +812,181 @@ def test_intensive_vs_legacy_scaling_difference(self) -> None:
812812
places=5,
813813
msg=f"Expected intensive/legacy ratio ~{expected_ratio:.6f}, got {actual_ratio:.6f}",
814814
)
815+
816+
817+
class TestEnerDefaultPf(CommonTest, LossTest, unittest.TestCase):
818+
"""Test energy loss with use_default_pf=True.
819+
820+
The pf term is activated through the default atom_pref of 1.0 even though
821+
`find_atom_pref` is 0.0 in the label. This exercises the cross-backend
822+
consistency between PT and DP for the new option. TF and Paddle backends
823+
raise NotImplementedError when use_default_pf=True and are skipped.
824+
"""
825+
826+
@property
827+
def data(self) -> dict:
828+
return {
829+
"start_pref_e": 0.02,
830+
"limit_pref_e": 1.0,
831+
"start_pref_f": 1000.0,
832+
"limit_pref_f": 1.0,
833+
"start_pref_v": 1.0,
834+
"limit_pref_v": 1.0,
835+
"start_pref_ae": 1.0,
836+
"limit_pref_ae": 1.0,
837+
"start_pref_pf": 1.0,
838+
"limit_pref_pf": 1.0,
839+
"use_default_pf": True,
840+
}
841+
842+
skip_tf = True
843+
skip_pd = True
844+
skip_pt = CommonTest.skip_pt
845+
skip_pt_expt = not INSTALLED_PT_EXPT
846+
skip_jax = not INSTALLED_JAX
847+
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
848+
849+
tf_class = EnerLossTF
850+
dp_class = EnerLossDP
851+
pt_class = EnerLossPT
852+
pt_expt_class = EnerLossPTExpt
853+
jax_class = EnerLossDP
854+
pd_class = EnerLossPD
855+
array_api_strict_class = EnerLossDP
856+
args = loss_ener()
857+
858+
def setUp(self) -> None:
859+
CommonTest.setUp(self)
860+
self.learning_rate = 1e-3
861+
rng = np.random.default_rng(20250105)
862+
self.nframes = 2
863+
self.natoms = 6
864+
self.predict = {
865+
"energy": rng.random((self.nframes,)),
866+
"force": rng.random((self.nframes, self.natoms, 3)),
867+
"virial": rng.random((self.nframes, 9)),
868+
"atom_ener": rng.random((self.nframes, self.natoms)),
869+
}
870+
self.predict_dpmodel_style = {
871+
"energy": self.predict["energy"],
872+
"force": self.predict["force"],
873+
"virial": self.predict["virial"],
874+
"atom_energy": self.predict["atom_ener"],
875+
}
876+
# find_atom_pref=0.0 simulates the case where atom_pref.npy is missing;
877+
# use_default_pf=True must override this and still compute the pf loss.
878+
self.label = {
879+
"energy": rng.random((self.nframes,)),
880+
"force": rng.random((self.nframes, self.natoms, 3)),
881+
"virial": rng.random((self.nframes, 9)),
882+
"atom_ener": rng.random((self.nframes, self.natoms)),
883+
"atom_pref": np.ones((self.nframes, self.natoms, 3)),
884+
"find_energy": 1.0,
885+
"find_force": 1.0,
886+
"find_virial": 1.0,
887+
"find_atom_ener": 1.0,
888+
"find_atom_pref": 0.0,
889+
}
890+
891+
@property
892+
def additional_data(self) -> dict:
893+
return {
894+
"starter_learning_rate": 1e-3,
895+
}
896+
897+
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
898+
# use_default_pf=True is not supported by TensorFlow; skip_tf is True so
899+
# this method is never invoked, but the abstract base requires it.
900+
raise NotImplementedError
901+
902+
def eval_pt(self, pt_obj: Any) -> Any:
903+
predict = {kk: numpy_to_torch(vv) for kk, vv in self.predict.items()}
904+
label = {kk: numpy_to_torch(vv) for kk, vv in self.label.items()}
905+
predict["atom_energy"] = predict.pop("atom_ener")
906+
_, loss, more_loss = pt_obj(
907+
{},
908+
lambda: predict,
909+
label,
910+
self.natoms,
911+
self.learning_rate,
912+
mae=False,
913+
)
914+
loss = torch_to_numpy(loss)
915+
more_loss = {kk: torch_to_numpy(vv) for kk, vv in more_loss.items()}
916+
return loss, more_loss
917+
918+
def eval_dp(self, dp_obj: Any) -> Any:
919+
return dp_obj(
920+
self.learning_rate,
921+
self.natoms,
922+
self.predict_dpmodel_style,
923+
self.label,
924+
mae=False,
925+
)
926+
927+
def eval_pt_expt(self, pt_expt_obj: Any) -> Any:
928+
predict = {
929+
kk: numpy_to_torch(vv) for kk, vv in self.predict_dpmodel_style.items()
930+
}
931+
label = {kk: numpy_to_torch(vv) for kk, vv in self.label.items()}
932+
loss, more_loss = pt_expt_obj(
933+
self.learning_rate,
934+
self.natoms,
935+
predict,
936+
label,
937+
mae=False,
938+
)
939+
loss = torch_to_numpy(loss)
940+
more_loss = {kk: torch_to_numpy(vv) for kk, vv in more_loss.items()}
941+
return loss, more_loss
942+
943+
def eval_jax(self, jax_obj: Any) -> Any:
944+
predict = {kk: jnp.asarray(vv) for kk, vv in self.predict_dpmodel_style.items()}
945+
label = {kk: jnp.asarray(vv) for kk, vv in self.label.items()}
946+
loss, more_loss = jax_obj(
947+
self.learning_rate,
948+
self.natoms,
949+
predict,
950+
label,
951+
mae=False,
952+
)
953+
loss = to_numpy_array(loss)
954+
more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()}
955+
return loss, more_loss
956+
957+
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
958+
predict = {
959+
kk: array_api_strict.asarray(vv)
960+
for kk, vv in self.predict_dpmodel_style.items()
961+
}
962+
label = {kk: array_api_strict.asarray(vv) for kk, vv in self.label.items()}
963+
loss, more_loss = array_api_strict_obj(
964+
self.learning_rate,
965+
self.natoms,
966+
predict,
967+
label,
968+
mae=False,
969+
)
970+
loss = to_numpy_array(loss)
971+
more_loss = {kk: to_numpy_array(vv) for kk, vv in more_loss.items()}
972+
return loss, more_loss
973+
974+
def extract_ret(self, ret: Any, backend) -> dict[str, np.ndarray]:
975+
loss = ret[0]
976+
result = {"loss": np.atleast_1d(np.asarray(loss, dtype=np.float64))}
977+
if len(ret) > 1:
978+
more_loss = ret[1]
979+
for k in sorted(more_loss):
980+
if k.startswith("rmse_") or k.startswith("mae_"):
981+
result[k] = np.atleast_1d(
982+
np.asarray(more_loss[k], dtype=np.float64)
983+
)
984+
return result
985+
986+
@property
987+
def rtol(self) -> float:
988+
return 1e-10
989+
990+
@property
991+
def atol(self) -> float:
992+
return 1e-10

source/tests/pt/test_loss_default_pf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def fake_model():
171171
)
172172
pt_loss_val = pt_loss.detach().cpu().numpy()
173173
# loss should be non-zero because pf loss is activated via use_default_pf
174-
self.assertTrue(pt_loss_val != 0.0)
174+
self.assertNotEqual(float(pt_loss_val), 0.0)
175175
self.assertIn("rmse_pf", pt_more_loss)
176176
# The pref_force_loss should be a valid number (not NaN)
177177
self.assertFalse(np.isnan(pt_more_loss["l2_pref_force_loss"]))
@@ -195,15 +195,15 @@ def fake_model():
195195
return self.model_pred
196196

197197
# With find_atom_pref=0.0 and use_default_pf=False, pf loss contribution is zero
198-
_, pt_loss_without, pt_more_loss_without = loss_fn(
198+
_, _pt_loss_without, pt_more_loss_without = loss_fn(
199199
{},
200200
fake_model,
201201
self.label_without_pref,
202202
self.nloc,
203203
self.cur_lr,
204204
)
205205
# With find_atom_pref=1.0, pf loss should be computed
206-
_, pt_loss_with, pt_more_loss_with = loss_fn(
206+
_, _pt_loss_with, pt_more_loss_with = loss_fn(
207207
{},
208208
fake_model,
209209
self.label_with_pref,

0 commit comments

Comments
 (0)