Skip to content

Commit e5f7ce5

Browse files
committed
feat(pt/dpmodel): add use_default_pf
1 parent 0828604 commit e5f7ce5

5 files changed

Lines changed: 361 additions & 7 deletions

File tree

deepmd/dpmodel/loss/ener.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ class EnergyLoss(Loss):
6868
The prefactor of generalized force loss at the end of the training.
6969
numb_generalized_coord : int
7070
The dimension of generalized coordinates.
71+
use_default_pf : bool
72+
If true, use default atom_pref of 1.0 for all atoms when atom_pref data is not provided.
73+
This allows using the prefactor force loss (pf) without requiring atom_pref.npy files.
7174
use_huber : bool
7275
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
7376
The loss function smoothly transitions between L2 and L1 loss:
@@ -110,6 +113,7 @@ def __init__(
110113
huber_delta: float = 0.01,
111114
loss_func: str = "mse",
112115
f_use_norm: bool = False,
116+
use_default_pf: bool = False,
113117
**kwargs: Any,
114118
) -> None:
115119
# Validate loss_func
@@ -149,6 +153,7 @@ def __init__(
149153
self.use_huber = use_huber
150154
self.huber_delta = huber_delta
151155
self.f_use_norm = f_use_norm
156+
self.use_default_pf = use_default_pf
152157
if self.f_use_norm and not (self.use_huber or self.loss_func == "mae"):
153158
raise RuntimeError(
154159
"f_use_norm can only be True when use_huber or loss_func='mae'."
@@ -182,7 +187,9 @@ def call(
182187
find_force = label_dict["find_force"]
183188
find_virial = label_dict["find_virial"]
184189
find_atom_ener = label_dict["find_atom_ener"]
185-
find_atom_pref = label_dict["find_atom_pref"]
190+
find_atom_pref = (
191+
label_dict["find_atom_pref"] if not self.use_default_pf else 1.0
192+
)
186193
xp = array_api_compat.array_namespace(
187194
energy,
188195
force,
@@ -477,6 +484,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
477484
must=False,
478485
high_prec=False,
479486
repeat=3,
487+
default=1.0,
480488
)
481489
)
482490
if self.has_gf > 0:
@@ -512,7 +520,7 @@ def serialize(self) -> dict:
512520
"""
513521
return {
514522
"@class": "EnergyLoss",
515-
"@version": 2,
523+
"@version": 3,
516524
"starter_learning_rate": self.starter_learning_rate,
517525
"start_pref_e": self.start_pref_e,
518526
"limit_pref_e": self.limit_pref_e,
@@ -533,6 +541,7 @@ def serialize(self) -> dict:
533541
"huber_delta": self.huber_delta,
534542
"loss_func": self.loss_func,
535543
"f_use_norm": self.f_use_norm,
544+
"use_default_pf": self.use_default_pf,
536545
}
537546

538547
@classmethod
@@ -550,6 +559,6 @@ def deserialize(cls, data: dict) -> "Loss":
550559
The deserialized loss module
551560
"""
552561
data = data.copy()
553-
check_version_compatibility(data.pop("@version"), 2, 1)
562+
check_version_compatibility(data.pop("@version"), 3, 1)
554563
data.pop("@class")
555564
return cls(**data)

deepmd/pt/loss/ener.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
loss_func: str = "mse",
5757
inference: bool = False,
5858
use_huber: bool = False,
59+
use_default_pf: bool = False,
5960
f_use_norm: bool = False,
6061
huber_delta: float = 0.01,
6162
**kwargs: Any,
@@ -103,6 +104,9 @@ def __init__(
103104
MAE loss is less sensitive to outliers compared to MSE loss.
104105
inference : bool
105106
If true, it will output all losses found in output, ignoring the pre-factors.
107+
use_default_pf : bool
108+
If true, use default atom_pref of 1.0 for all atoms when atom_pref data is not provided.
109+
This allows using the prefactor force loss (pf) without requiring atom_pref.npy files.
106110
use_huber : bool
107111
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
108112
The loss function smoothly transitions between L2 and L1 loss:
@@ -147,6 +151,7 @@ def __init__(
147151
self.limit_pref_pf = limit_pref_pf
148152
self.start_pref_gf = start_pref_gf
149153
self.limit_pref_gf = limit_pref_gf
154+
self.use_default_pf = use_default_pf
150155
self.relative_f = relative_f
151156
self.enable_atom_ener_coeff = enable_atom_ener_coeff
152157
self.numb_generalized_coord = numb_generalized_coord
@@ -357,7 +362,9 @@ def forward(
357362

358363
if self.has_pf and "atom_pref" in label:
359364
atom_pref = label["atom_pref"]
360-
find_atom_pref = label.get("find_atom_pref", 0.0)
365+
find_atom_pref = (
366+
label.get("find_atom_pref", 0.0) if not self.use_default_pf else 1.0
367+
)
361368
pref_pf = pref_pf * find_atom_pref
362369
atom_pref_reshape = atom_pref.reshape(-1)
363370

@@ -514,7 +521,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
514521
high_prec=True,
515522
)
516523
)
517-
if self.has_f:
524+
if self.has_f or self.has_pf or self.relative_f is not None or self.has_gf:
518525
label_requirement.append(
519526
DataRequirementItem(
520527
"force",
@@ -553,6 +560,7 @@ def label_requirement(self) -> list[DataRequirementItem]:
553560
must=False,
554561
high_prec=False,
555562
repeat=3,
563+
default=1.0,
556564
)
557565
)
558566
if self.has_gf > 0:
@@ -588,7 +596,7 @@ def serialize(self) -> dict:
588596
"""
589597
return {
590598
"@class": "EnergyLoss",
591-
"@version": 2,
599+
"@version": 3,
592600
"starter_learning_rate": self.starter_learning_rate,
593601
"start_pref_e": self.start_pref_e,
594602
"limit_pref_e": self.limit_pref_e,
@@ -609,6 +617,7 @@ def serialize(self) -> dict:
609617
"huber_delta": self.huber_delta,
610618
"loss_func": self.loss_func,
611619
"f_use_norm": self.f_use_norm,
620+
"use_default_pf": self.use_default_pf,
612621
}
613622

614623
@classmethod
@@ -626,7 +635,7 @@ def deserialize(cls, data: dict) -> "TaskLoss":
626635
The deserialized loss module
627636
"""
628637
data = data.copy()
629-
check_version_compatibility(data.pop("@version"), 2, 1)
638+
check_version_compatibility(data.pop("@version"), 3, 1)
630639
data.pop("@class")
631640
return cls(**data)
632641

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3189,6 +3189,11 @@ def loss_ener() -> list[Argument]:
31893189
"atomic prefactor force", label="atom_pref", abbr="pf"
31903190
)
31913191
doc_limit_pref_pf = limit_pref("atomic prefactor force")
3192+
doc_use_default_pf = (
3193+
"If true, use default atom_pref of 1.0 for all atoms when atom_pref data is not provided. "
3194+
"This allows using the prefactor force loss (pf) without requiring atom_pref.npy files in training data. "
3195+
"When atom_pref.npy is provided, it will be used as-is regardless of this setting."
3196+
)
31923197
doc_start_pref_gf = start_pref("generalized force", label="drdq", abbr="gf")
31933198
doc_limit_pref_gf = limit_pref("generalized force")
31943199
doc_numb_generalized_coord = "The dimension of generalized coordinates. Required when generalized force loss is used."
@@ -3299,6 +3304,13 @@ def loss_ener() -> list[Argument]:
32993304
default=0.00,
33003305
doc=doc_limit_pref_pf,
33013306
),
3307+
Argument(
3308+
"use_default_pf",
3309+
bool,
3310+
optional=True,
3311+
default=False,
3312+
doc=doc_use_default_pf,
3313+
),
33023314
Argument("relative_f", [float, None], optional=True, doc=doc_relative_f),
33033315
Argument(
33043316
"enable_atom_ener_coeff",

doc/model/train-se-a-mask.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,22 @@ And the `loss` section in the training input script should be set as follows.
8585
}
8686
```
8787

88+
If `atom_pref.npy` is not provided in the training data, one can set `use_default_pf` to `true` to use a default atom preference of 1.0 for all atoms. This allows using the prefactor force loss (`pf` loss) without requiring `atom_pref.npy` files. When `atom_pref.npy` is provided, it will be used as-is regardless of this setting.
89+
90+
```json
91+
"loss": {
92+
"type": "ener",
93+
"start_pref_e": 0.0,
94+
"limit_pref_e": 0.0,
95+
"start_pref_f": 0.0,
96+
"limit_pref_f": 0.0,
97+
"start_pref_pf": 1.0,
98+
"limit_pref_pf": 1.0,
99+
"use_default_pf": true,
100+
"_comment": " that's all"
101+
}
102+
```
103+
88104
## Type embedding
89105

90106
Same as [`se_e2_a`](./train-se-e2-a.md).

0 commit comments

Comments
 (0)