Skip to content

Commit 91b9e68

Browse files
committed
add extra_fact
1 parent 7c34d93 commit 91b9e68

2 files changed

Lines changed: 18 additions & 2 deletions

File tree

deepmd/pt/model/task/ener.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,8 @@ def __init__(
536536
add_angle_readout: bool = False,
537537
slim_edge_readout: bool = False,
538538
slim_angle_readout: bool = False,
539+
edge_extra_fact: float = 1.0,
540+
angle_extra_fact: float = 1.0,
539541
**kwargs,
540542
) -> None:
541543
"""Construct a fitting net for energy.
@@ -549,6 +551,8 @@ def __init__(
549551
"""
550552
self.add_edge_readout = add_edge_readout
551553
self.add_angle_readout = add_angle_readout
554+
self.edge_extra_fact = edge_extra_fact
555+
self.angle_extra_fact = angle_extra_fact
552556
super().__init__(
553557
"energy",
554558
ntypes,
@@ -714,7 +718,7 @@ def forward(
714718
# nf x nloc x 1
715719
edge_energy = torch.sum(edge_atomic_contrib, dim=-2)
716720
# energy
717-
out = out + edge_energy / self.norm_e_fact
721+
out = out + (edge_energy * self.edge_extra_fact) / self.norm_e_fact
718722

719723
if self.add_angle_readout:
720724
assert angle_embd is not None
@@ -747,5 +751,5 @@ def forward(
747751
)
748752
# energy
749753
# self.norm_a_fact ** 2
750-
out = out + angle_energy / (self.norm_a_fact**2)
754+
out = out + (angle_energy * self.angle_extra_fact) / (self.norm_a_fact**2)
751755
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,6 +2220,18 @@ def fitting_ener_readout():
22202220
optional=True,
22212221
default=False,
22222222
),
2223+
Argument(
2224+
"edge_extra_fact",
2225+
float,
2226+
optional=True,
2227+
default=1.0,
2228+
),
2229+
Argument(
2230+
"angle_extra_fact",
2231+
float,
2232+
optional=True,
2233+
default=1.0,
2234+
),
22232235
]
22242236

22252237

0 commit comments

Comments
 (0)