Skip to content

Commit e95f0d6

Browse files
Author: Han Wang (committed)
refactor(pt_expt): re-export loss classes from dpmodel instead of wrapping
Loss classes hold only plain Python floats/bools (no tensors, no NativeOP sub-components), so @torch_module wrapping is unnecessary. Re-export dpmodel classes directly. Remove .to(device) calls in tests and relax float32 tolerance in spin loss test.
1 parent fb2027b commit e95f0d6

10 files changed

Lines changed: 28 additions & 46 deletions

File tree

deepmd/pt_expt/loss/dos.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from deepmd.dpmodel.loss.dos import DOSLoss as DOSLossDP
3-
from deepmd.pt_expt.common import (
4-
torch_module,
2+
from deepmd.dpmodel.loss.dos import (
3+
DOSLoss,
54
)
65

7-
8-
@torch_module
9-
class DOSLoss(DOSLossDP):
10-
pass
6+
__all__ = ["DOSLoss"]

deepmd/pt_expt/loss/ener.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from deepmd.dpmodel.loss.ener import EnergyLoss as EnergyLossDP
3-
from deepmd.pt_expt.common import (
4-
torch_module,
2+
from deepmd.dpmodel.loss.ener import (
3+
EnergyLoss,
54
)
65

7-
8-
@torch_module
9-
class EnergyLoss(EnergyLossDP):
10-
pass
6+
__all__ = ["EnergyLoss"]

deepmd/pt_expt/loss/ener_spin.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from deepmd.dpmodel.loss.ener_spin import EnergySpinLoss as EnergySpinLossDP
3-
from deepmd.pt_expt.common import (
4-
torch_module,
2+
from deepmd.dpmodel.loss.ener_spin import (
3+
EnergySpinLoss,
54
)
65

7-
8-
@torch_module
9-
class EnergySpinLoss(EnergySpinLossDP):
10-
pass
6+
__all__ = ["EnergySpinLoss"]

deepmd/pt_expt/loss/property.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from deepmd.dpmodel.loss.property import PropertyLoss as PropertyLossDP
3-
from deepmd.pt_expt.common import (
4-
torch_module,
2+
from deepmd.dpmodel.loss.property import (
3+
PropertyLoss,
54
)
65

7-
8-
@torch_module
9-
class PropertyLoss(PropertyLossDP):
10-
pass
6+
__all__ = ["PropertyLoss"]

deepmd/pt_expt/loss/tensor.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from deepmd.dpmodel.loss.tensor import TensorLoss as TensorLossDP
3-
from deepmd.pt_expt.common import (
4-
torch_module,
2+
from deepmd.dpmodel.loss.tensor import (
3+
TensorLoss,
54
)
65

7-
8-
@torch_module
9-
class TensorLoss(TensorLossDP):
10-
pass
6+
__all__ = ["TensorLoss"]

source/tests/pt_expt/loss/test_dos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_consistency(self, prec, has_dos, has_ados) -> None:
8585
limit_pref_ados=0.5 if has_ados else 0.0,
8686
start_pref_acdf=0.0,
8787
limit_pref_acdf=0.0,
88-
).to(self.device)
88+
)
8989

9090
model_pred, label = _make_data(
9191
rng, nframes, natoms, numb_dos, dtype, self.device
@@ -152,7 +152,7 @@ def test_cdf_terms(self, prec) -> None:
152152
limit_pref_ados=0.0,
153153
start_pref_acdf=1.0,
154154
limit_pref_acdf=0.5,
155-
).to(self.device)
155+
)
156156

157157
model_pred, label = _make_data(
158158
rng, nframes, natoms, numb_dos, dtype, self.device

source/tests/pt_expt/loss/test_ener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_consistency(self, prec, use_huber) -> None:
9898
start_pref_pf=0.0 if use_huber else 1.0,
9999
limit_pref_pf=0.0 if use_huber else 1.0,
100100
use_huber=use_huber,
101-
).to(self.device)
101+
)
102102

103103
model_pred, label = _make_data(rng, nframes, natoms, dtype, self.device)
104104

source/tests/pt_expt/loss/test_ener_spin.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def test_consistency(self, prec, loss_func) -> None:
8686
nframes, natoms, n_magnetic = 2, 6, 4
8787
dtype = PRECISION_DICT[prec]
8888
rtol, atol = get_tols(prec)
89+
if prec in ["single", "float32"]:
90+
atol = max(atol, 2e-4) # relax for float32 rounding across envs
8991
learning_rate = 1e-3
9092

9193
loss0 = EnergySpinLoss(
@@ -101,7 +103,7 @@ def test_consistency(self, prec, loss_func) -> None:
101103
start_pref_ae=1.0,
102104
limit_pref_ae=1.0,
103105
loss_func=loss_func,
104-
).to(self.device)
106+
)
105107

106108
model_pred, label = _make_data(
107109
rng, nframes, natoms, n_magnetic, dtype, self.device
@@ -172,7 +174,7 @@ def test_partial_mask(self, prec) -> None:
172174
limit_pref_v=0.0,
173175
start_pref_ae=0.0,
174176
limit_pref_ae=0.0,
175-
).to(self.device)
177+
)
176178

177179
model_pred, label = _make_data(
178180
rng, nframes, natoms, n_magnetic, dtype, self.device
@@ -223,7 +225,7 @@ def test_all_masked(self, prec) -> None:
223225
limit_pref_v=0.0,
224226
start_pref_ae=0.0,
225227
limit_pref_ae=0.0,
226-
).to(self.device)
228+
)
227229

228230
model_pred, label = _make_data(
229231
rng, nframes, natoms, n_magnetic, dtype, self.device

source/tests/pt_expt/loss/test_property.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_consistency(self, prec, loss_func) -> None:
7777
out_bias=[0.1, 0.5, 1.2, -0.1, -10.0],
7878
out_std=[8.0, 10.0, 0.001, -0.2, -10.0],
7979
intensive=False,
80-
).to(self.device)
80+
)
8181

8282
model_pred, label = _make_data(
8383
rng, nframes, task_dim, var_name, dtype, self.device
@@ -143,7 +143,7 @@ def test_intensive(self, prec) -> None:
143143
out_bias=None,
144144
out_std=None,
145145
intensive=True,
146-
).to(self.device)
146+
)
147147

148148
model_pred, label = _make_data(
149149
rng, nframes, task_dim, var_name, dtype, self.device
@@ -187,7 +187,7 @@ def test_no_out_bias_std(self, prec) -> None:
187187
out_bias=None,
188188
out_std=None,
189189
intensive=False,
190-
).to(self.device)
190+
)
191191

192192
model_pred, label = _make_data(
193193
rng, nframes, task_dim, var_name, dtype, self.device

source/tests/pt_expt/loss/test_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_consistency(self, prec, has_local, has_global) -> None:
8585
label_name=label_name,
8686
pref_atomic=1.0 if has_local else 0.0,
8787
pref=1.0 if has_global else 0.0,
88-
).to(self.device)
88+
)
8989

9090
model_pred, label = _make_data(
9191
rng,
@@ -158,7 +158,7 @@ def test_with_atomic_weight(self, prec) -> None:
158158
pref_atomic=1.0,
159159
pref=1.0,
160160
enable_atomic_weight=True,
161-
).to(self.device)
161+
)
162162

163163
model_pred, label = _make_data(
164164
rng,

0 commit comments

Comments (0)