Skip to content

Commit f774cd2

Browse files
author
Han Wang
committed
test(pt_expt): port silut activation + repformers accessors from deepmodeling#5393
Add silut/custom_silu support to _torch_activation using native torch ops (sigmoid, tanh, where) so the custom silu stays traceable by make_fx / torch.export. Cross-backend consistency tests cover multiple thresholds across the silu/tanh branches, and a pt_expt unit file exercises default/custom threshold, gradient flow, make_fx, and torch.export. Also port DescrptBlockRepformers accessor tests (get_rcut_smth, get_env_protection). The underlying accessor methods already exist on this branch; these tests guard against regressions.
1 parent aabb710 commit f774cd2

4 files changed

Lines changed: 224 additions & 0 deletions

File tree

deepmd/pt_expt/utils/network.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import math
23
from typing import (
34
Any,
45
ClassVar,
@@ -182,6 +183,14 @@ def _torch_activation(x: torch.Tensor, name: str) -> torch.Tensor:
182183
return torch.sigmoid(x)
183184
elif name == "silu":
184185
return torch.nn.functional.silu(x)
186+
elif name.startswith("silut") or name.startswith("custom_silu"):
187+
threshold = float(name.split(":")[-1]) if ":" in name else 3.0
188+
sig_t = 1.0 / (1.0 + math.exp(-threshold))
189+
slope = sig_t + threshold * sig_t * (1.0 - sig_t)
190+
const = threshold * sig_t
191+
silu = x * torch.sigmoid(x)
192+
tanh_branch = torch.tanh(slope * (x - threshold)) + const
193+
return torch.where(x < threshold, silu, tanh_branch)
185194
elif name in ("none", "linear"):
186195
return x
187196
else:

source/tests/common/dpmodel/test_descriptor_dpa2.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
RepformerArgs,
1111
RepinitArgs,
1212
)
13+
from deepmd.dpmodel.descriptor.repformers import (
14+
DescrptBlockRepformers,
15+
)
1316

1417
from ...seed import (
1518
GLOBAL_SEED,
@@ -69,3 +72,36 @@ def test_self_consistency(
6972
for ii in [0, 1, 2, 3, 4]:
7073
np.testing.assert_equal(mm0[ii].shape, desired_shape[ii])
7174
np.testing.assert_allclose(mm0[ii], mm1[ii])
75+
76+
77+
class TestDescrptBlockRepformersAccessors(unittest.TestCase):
78+
def test_get_rcut_smth(self) -> None:
79+
block = DescrptBlockRepformers(
80+
rcut=6.0,
81+
rcut_smth=5.0,
82+
sel=40,
83+
ntypes=2,
84+
nlayers=3,
85+
)
86+
self.assertEqual(block.get_rcut_smth(), 5.0)
87+
88+
def test_get_env_protection(self) -> None:
89+
block = DescrptBlockRepformers(
90+
rcut=6.0,
91+
rcut_smth=5.0,
92+
sel=40,
93+
ntypes=2,
94+
nlayers=3,
95+
env_protection=1.0,
96+
)
97+
self.assertEqual(block.get_env_protection(), 1.0)
98+
99+
def test_get_env_protection_default(self) -> None:
100+
block = DescrptBlockRepformers(
101+
rcut=6.0,
102+
rcut_smth=5.0,
103+
sel=40,
104+
ntypes=2,
105+
nlayers=3,
106+
)
107+
self.assertEqual(block.get_env_protection(), 0.0)

source/tests/consistent/test_activation.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
INSTALLED_JAX,
2020
INSTALLED_PD,
2121
INSTALLED_PT,
22+
INSTALLED_PT_EXPT,
2223
INSTALLED_TF,
2324
parameterized,
2425
)
@@ -29,6 +30,13 @@
2930
from deepmd.pt.utils.utils import (
3031
to_torch_tensor,
3132
)
33+
if INSTALLED_PT_EXPT:
34+
import torch
35+
36+
from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE
37+
from deepmd.pt_expt.utils.network import (
38+
_torch_activation,
39+
)
3240
if INSTALLED_TF:
3341
from deepmd.tf.common import get_activation_func as get_activation_fn_tf
3442
from deepmd.tf.env import (
@@ -98,3 +106,54 @@ def test_pd_consistent_with_ref(self):
98106
ActivationFn_pd(self.activation)(to_paddle_tensor(self.random_input))
99107
)
100108
np.testing.assert_allclose(self.ref, test, atol=1e-10)
109+
110+
@unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed")
111+
def test_pt_expt_consistent_with_ref(self) -> None:
112+
if INSTALLED_PT_EXPT:
113+
x = torch.tensor(
114+
self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE
115+
)
116+
test = _torch_activation(x, self.activation).detach().numpy()
117+
np.testing.assert_allclose(self.ref, test, atol=1e-10)
118+
119+
120+
@parameterized(
121+
(
122+
"silut", # default threshold 3.0
123+
"silut:3.0", # explicit threshold 3.0
124+
"silut:10.0", # large threshold
125+
"custom_silu:5.0", # alias
126+
),
127+
)
128+
class TestSilutVariantsConsistent(unittest.TestCase):
129+
"""Cross-backend consistency for silut with different thresholds."""
130+
131+
def setUp(self) -> None:
132+
(self.activation,) = self.param
133+
# Parse threshold to build input that covers both branches
134+
threshold = (
135+
float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
136+
)
137+
rng = np.random.default_rng(GLOBAL_SEED)
138+
# Values below threshold (silu branch) and above threshold (tanh branch)
139+
below = rng.uniform(-threshold - 5, threshold - 0.1, size=(5, 10))
140+
above = rng.uniform(threshold + 0.1, threshold + 20, size=(5, 10))
141+
self.random_input = np.concatenate([below, above], axis=0)
142+
self.ref = get_activation_fn_dp(self.activation)(self.random_input)
143+
144+
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
145+
def test_pt_consistent_with_ref(self) -> None:
146+
if INSTALLED_PT:
147+
test = torch_to_numpy(
148+
ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input))
149+
)
150+
np.testing.assert_allclose(self.ref, test, atol=1e-10)
151+
152+
@unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed")
153+
def test_pt_expt_consistent_with_ref(self) -> None:
154+
if INSTALLED_PT_EXPT:
155+
x = torch.tensor(
156+
self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE
157+
)
158+
test = _torch_activation(x, self.activation).detach().numpy()
159+
np.testing.assert_allclose(self.ref, test, atol=1e-10)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import numpy as np
3+
import torch
4+
from torch.fx.experimental.proxy_tensor import (
5+
make_fx,
6+
)
7+
8+
from deepmd.dpmodel.utils.network import (
9+
get_activation_fn,
10+
)
11+
from deepmd.pt_expt.utils.network import (
12+
_torch_activation,
13+
)
14+
15+
16+
class TestSilutActivation:
17+
"""Tests for silut activation in _torch_activation."""
18+
19+
def setup_method(self) -> None:
20+
# x values spanning both branches: below threshold and above
21+
self.x_np = np.array(
22+
[-5.0, -1.0, 0.0, 1.0, 2.5, 3.0, 5.0, 10.0, 15.0, 20.0],
23+
dtype=np.float64,
24+
)
25+
self.x_torch = torch.tensor(self.x_np, dtype=torch.float64)
26+
27+
def test_silut_with_threshold(self) -> None:
28+
"""silut:10.0 matches dpmodel numerically."""
29+
result = _torch_activation(self.x_torch, "silut:10.0")
30+
dp_fn = get_activation_fn("silut:10.0")
31+
expected = dp_fn(self.x_np)
32+
np.testing.assert_allclose(
33+
result.detach().numpy(), expected, rtol=1e-12, atol=1e-12
34+
)
35+
36+
def test_silut_default_threshold(self) -> None:
37+
"""Silut without parameter uses default threshold 3.0."""
38+
result = _torch_activation(self.x_torch, "silut")
39+
dp_fn = get_activation_fn("silut")
40+
expected = dp_fn(self.x_np)
41+
np.testing.assert_allclose(
42+
result.detach().numpy(), expected, rtol=1e-12, atol=1e-12
43+
)
44+
45+
def test_silut_custom_silu_alias(self) -> None:
46+
"""custom_silu:5.0 is an alias for silut:5.0."""
47+
result = _torch_activation(self.x_torch, "custom_silu:5.0")
48+
dp_fn = get_activation_fn("custom_silu:5.0")
49+
expected = dp_fn(self.x_np)
50+
np.testing.assert_allclose(
51+
result.detach().numpy(), expected, rtol=1e-12, atol=1e-12
52+
)
53+
54+
def test_silut_gradient(self) -> None:
55+
"""Gradient flows through both branches of silut."""
56+
x = self.x_torch.clone().requires_grad_(True)
57+
y = _torch_activation(x, "silut:3.0")
58+
loss = y.sum()
59+
loss.backward()
60+
grad = x.grad
61+
assert grad is not None
62+
# gradient should be finite everywhere
63+
assert torch.all(torch.isfinite(grad))
64+
# gradient should be non-zero for non-zero inputs
65+
nonzero_mask = self.x_np != 0.0
66+
assert torch.all(grad[nonzero_mask] != 0.0)
67+
68+
def test_silut_make_fx(self) -> None:
69+
"""make_fx can trace through silut activation."""
70+
71+
def fn(x: torch.Tensor) -> torch.Tensor:
72+
return _torch_activation(x, "silut:10.0")
73+
74+
traced = make_fx(fn)(self.x_torch)
75+
result = traced(self.x_torch)
76+
expected = _torch_activation(self.x_torch, "silut:10.0")
77+
np.testing.assert_allclose(
78+
result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12
79+
)
80+
81+
def test_silut_below_threshold_is_silu(self) -> None:
82+
"""Below threshold, silut equals silu exactly."""
83+
threshold = 10.0
84+
x_below = torch.tensor([-5.0, 0.0, 1.0, 5.0, 9.9], dtype=torch.float64)
85+
result = _torch_activation(x_below, "silut:10.0")
86+
silu = x_below * torch.sigmoid(x_below)
87+
np.testing.assert_allclose(
88+
result.detach().numpy(), silu.detach().numpy(), rtol=1e-14, atol=1e-14
89+
)
90+
91+
def test_silut_above_threshold_is_tanh_branch(self) -> None:
92+
"""Above threshold, silut equals tanh(slope*(x-T))+const."""
93+
import math
94+
95+
threshold = 3.0
96+
sig_t = 1.0 / (1.0 + math.exp(-threshold))
97+
slope = sig_t + threshold * sig_t * (1.0 - sig_t)
98+
const = threshold * sig_t
99+
100+
x_above = torch.tensor([3.5, 5.0, 10.0, 20.0], dtype=torch.float64)
101+
result = _torch_activation(x_above, "silut:3.0")
102+
expected = torch.tanh(slope * (x_above - threshold)) + const
103+
np.testing.assert_allclose(
104+
result.detach().numpy(), expected.detach().numpy(), rtol=1e-14, atol=1e-14
105+
)
106+
107+
def test_silut_export(self) -> None:
108+
"""torch.export.export can trace through silut activation."""
109+
110+
class SilutModule(torch.nn.Module):
111+
def forward(self, x: torch.Tensor) -> torch.Tensor:
112+
return _torch_activation(x, "silut:10.0")
113+
114+
mod = SilutModule()
115+
exported = torch.export.export(mod, (self.x_torch,))
116+
result = exported.module()(self.x_torch)
117+
expected = _torch_activation(self.x_torch, "silut:10.0")
118+
np.testing.assert_allclose(
119+
result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12
120+
)

0 commit comments

Comments
 (0)