Skip to content

Commit 9b3661a

Browse files
author
Han Wang
committed
test: add functype=0 (linear) test for tabulate derivatives
1 parent 28cd4a2 commit 9b3661a

1 file changed

Lines changed: 37 additions & 0 deletions

File tree

source/tests/pt/test_tabulate.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
7: "silu",
2828
}
2929

30+
# functype=0 (linear/none) is not supported by TF custom ops,
31+
# so we test it separately against numerical derivatives.
32+
ACTIVATION_NAMES_NUMPY_ONLY = {
33+
0: "linear",
34+
}
35+
3036

3137
def get_activation_function(functype: int):
3238
"""Get activation function corresponding to functype."""
@@ -185,6 +191,37 @@ def _test_single_activation(
185191
err_msg=f"unaggregated_dy2_dx failed for {activation_name}",
186192
)
187193

194+
def test_linear_activation(self) -> None:
195+
"""Test functype=0 (linear/none) against numerical derivatives.
196+
197+
TF custom ops don't support functype=0, so we validate against
198+
finite-difference derivatives instead.
199+
"""
200+
from deepmd.utils.tabulate_math import (
201+
grad,
202+
grad_grad,
203+
)
204+
205+
fn = get_activation_fn("linear")
206+
y = fn(self.xbar)
207+
h = 1e-7
208+
209+
# grad: f'(x) = 1 for identity
210+
dy_ana = grad(self.xbar, y, 0)
211+
np.testing.assert_allclose(dy_ana, np.ones_like(self.xbar), atol=1e-12)
212+
213+
# grad_grad: f''(x) = 0 for identity
214+
dy2_ana = grad_grad(self.xbar, y, 0)
215+
np.testing.assert_allclose(dy2_ana, np.zeros_like(self.xbar), atol=1e-12)
216+
217+
# Also verify unaggregated functions work with functype=0
218+
dy = unaggregated_dy_dx_s(y, self.w, self.xbar, 0)
219+
self.assertEqual(dy.shape, (4, 4))
220+
221+
dy2 = unaggregated_dy2_dx_s(y, dy, self.w, self.xbar, 0)
222+
# Second derivative of identity is zero everywhere
223+
np.testing.assert_allclose(dy2, np.zeros_like(dy2), atol=1e-12)
224+
188225

189226
if __name__ == "__main__":
190227
unittest.main()

0 commit comments

Comments
 (0)