Skip to content

Commit 7d8a307

Browse files
committed
fix ut
1 parent 28e6fdd commit 7d8a307

2 files changed

Lines changed: 12 additions & 2 deletions

File tree

deepmd/pd/loss/ener.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ def __init__(
125125
raise NotImplementedError(
126126
"Paddle backend does not support f_use_norm=True."
127127
)
128+
if kwargs.get("use_default_pf", False):
129+
raise NotImplementedError(
130+
"Paddle backend does not support use_default_pf=True."
131+
)
128132

129133
self.starter_learning_rate = starter_learning_rate
130134
self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
@@ -554,7 +558,7 @@ def serialize(self) -> dict:
554558
"""
555559
return {
556560
"@class": "EnergyLoss",
557-
"@version": 2,
561+
"@version": 3,
558562
"starter_learning_rate": self.starter_learning_rate,
559563
"start_pref_e": self.start_pref_e,
560564
"limit_pref_e": self.limit_pref_e,
@@ -575,6 +579,7 @@ def serialize(self) -> dict:
575579
"huber_delta": self.huber_delta,
576580
"loss_func": self.loss_func,
577581
"f_use_norm": self.f_use_norm,
582+
"use_default_pf": getattr(self, "use_default_pf", False),
578583
}
579584

580585
@classmethod

deepmd/tf/loss/ener.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ def __init__(
133133
raise NotImplementedError(
134134
"TensorFlow backend does not support f_use_norm=True."
135135
)
136+
if kwargs.get("use_default_pf", False):
137+
raise NotImplementedError(
138+
"TensorFlow backend does not support use_default_pf=True."
139+
)
136140

137141
self.starter_learning_rate = starter_learning_rate
138142
self.start_pref_e = start_pref_e
@@ -531,7 +535,7 @@ def serialize(self, suffix: str = "") -> dict:
531535
"""
532536
return {
533537
"@class": "EnergyLoss",
534-
"@version": 2,
538+
"@version": 3,
535539
"starter_learning_rate": self.starter_learning_rate,
536540
"start_pref_e": self.start_pref_e,
537541
"limit_pref_e": self.limit_pref_e,
@@ -552,6 +556,7 @@ def serialize(self, suffix: str = "") -> dict:
552556
"huber_delta": self.huber_delta,
553557
"loss_func": self.loss_func,
554558
"f_use_norm": self.f_use_norm,
559+
"use_default_pf": getattr(self, "use_default_pf", False),
555560
}
556561

557562
@classmethod

0 commit comments

Comments
 (0)