Skip to content

Commit b23498b

Browse files
committed
fix ut
1 parent 3341fd1 commit b23498b

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

deepmd/tf/loss/ener.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def __init__(
123123
f_use_norm: bool = False,
124124
**kwargs: Any,
125125
) -> None:
126+
if loss_func != "mse":
127+
raise NotImplementedError(
128+
f"TensorFlow backend only supports loss_func='mse', got '{loss_func}'."
129+
)
126130
self.loss_func = loss_func
127131
self.f_use_norm = f_use_norm
128132

@@ -588,7 +592,13 @@ def __init__(
588592
relative_f: float | None = None,
589593
enable_atom_ener_coeff: bool = False,
590594
use_spin: list | None = None,
595+
loss_func: str = "mse",
591596
) -> None:
597+
if loss_func != "mse":
598+
raise NotImplementedError(
599+
f"TensorFlow backend only supports loss_func='mse', got '{loss_func}'."
600+
)
601+
self.loss_func = loss_func
592602
self.starter_learning_rate = starter_learning_rate
593603
self.start_pref_e = start_pref_e
594604
self.limit_pref_e = limit_pref_e

0 commit comments

Comments
 (0)