diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 930292db58..8a3ae773b7 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """PyTorch-specific DPTabulate wrapper. -Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate`` -and adds torch-specific ``_convert_numpy_to_tensor`` and -``_get_descrpt_type`` (isinstance checks against PT descriptor classes). +Inherits the shared Array API tabulation math from +``deepmd.utils.tabulate_math.DPTabulate`` and adds torch-specific +backend selection, tensor conversion, and descriptor type detection. """ from typing import ( @@ -31,8 +31,8 @@ class DPTabulate(DPTabulateBase): r"""PyTorch tabulation wrapper. - Accepts a PT ``ActivationFn`` module and delegates all math to the - numpy base class. Only overrides tensor conversion and descriptor + Accepts a PT ``ActivationFn`` module and delegates all shared math to the + base class. Overrides backend selection, tensor conversion, and descriptor type detection. Parameters @@ -91,6 +91,10 @@ def _get_descrpt_type(self) -> str: return "T_TEBD" raise RuntimeError(f"Unsupported descriptor {self.descrpt}") + def _get_math_backend_sample(self) -> Any: + """Run shared tabulation math on the current torch device.""" + return torch.empty((), dtype=torch.float64, device=env.DEVICE) + def _convert_numpy_to_tensor(self) -> None: """Convert self.data from np.ndarray to torch.Tensor.""" self._convert_numpy_float_to_int() diff --git a/deepmd/pt_expt/utils/tabulate.py b/deepmd/pt_expt/utils/tabulate.py index 4cda06d9b9..1cc3b3912d 100644 --- a/deepmd/pt_expt/utils/tabulate.py +++ b/deepmd/pt_expt/utils/tabulate.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """DPTabulate for the pt_expt backend. -Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate`` -and overrides ``_convert_numpy_to_tensor`` for torch tensor conversion -and ``_get_descrpt_type`` for serialization-based type detection. +Inherits the shared Array API tabulation math from +``deepmd.utils.tabulate_math.DPTabulate`` and overrides backend +selection, tensor conversion, and descriptor type detection. No dependency on the pt backend. """ @@ -70,6 +70,16 @@ def _get_descrpt_type(self) -> str: raise RuntimeError(f"Unsupported descriptor type: {type_str}") return descrpt_type + def _get_math_backend_sample(self) -> Any: + """Run shared tabulation math on the pt_expt torch device.""" + import torch + + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + return torch.empty((), dtype=torch.float64, device=DEVICE) + def _convert_numpy_to_tensor(self) -> None: """Convert self.data from np.ndarray to torch.Tensor.""" import torch diff --git a/deepmd/utils/tabulate_math.py b/deepmd/utils/tabulate_math.py index 67b56b311e..93fe903e12 100644 --- a/deepmd/utils/tabulate_math.py +++ b/deepmd/utils/tabulate_math.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Backend-agnostic tabulation math using numpy. +"""Backend-agnostic tabulation math using the Array API where possible. Provides the pure-math functions for model compression tabulation: activation derivatives, chain-rule derivative propagation, and @@ -14,8 +14,12 @@ Any, ) +import array_api_compat import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.network import ( get_activation_fn, ) @@ -44,58 +48,67 @@ } -# ---- Activation derivatives (numpy) ---- +# ---- Activation derivatives (Array API compatible) ---- -def _stable_sigmoid(xbar: np.ndarray) -> np.ndarray: +def _stable_sigmoid(xbar: Any) -> Any: """Compute sigmoid without overflow for large-magnitude inputs.""" + xp = array_api_compat.array_namespace(xbar) positive = xbar >= 0 - exp_neg_abs = np.exp(np.where(positive, -xbar, xbar)) - return np.where( + exp_neg_abs = xp.exp(xp.where(positive, -xbar, xbar)) + return xp.where( positive, 1.0 / (1.0 + exp_neg_abs), exp_neg_abs / (1.0 + exp_neg_abs), ) -def grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: +def _repeat_flattened_weight_prefix(w: Any, rows: int, cols: int) -> Any: + """Repeat the flattened weight prefix row-wise in an Array API way.""" + xp = array_api_compat.array_namespace(w) + w_flat = xp.reshape(w, (-1,))[:cols] + w_flat = xp.reshape(w_flat, (1, cols)) + return xp.broadcast_to(w_flat, (rows, cols)) + + +def grad(xbar: Any, y: Any, functype: int) -> Any: """First derivative of the activation function.""" + xp = array_api_compat.array_namespace(xbar, y) if functype == 0: - return np.ones_like(xbar) + return xp.ones_like(xbar) elif functype == 1: return 1 - y * y elif functype == 2: - var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var = xp.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) return ( 0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1) + 0.5 * var + 0.5 ) elif functype == 3: - return np.where(xbar > 0, np.ones_like(xbar), np.zeros_like(xbar)) + return xp.astype(xbar > 0, xbar.dtype) elif functype == 4: - return np.where( - (xbar > 0) & (xbar < 6), np.ones_like(xbar), np.zeros_like(xbar) - ) + return xp.astype((xbar > 0) & (xbar < 6), xbar.dtype) elif functype == 5: return _stable_sigmoid(xbar) elif functype == 6: return y * (1 - y) elif functype == 7: - sig = 1.0 / (1.0 + np.exp(-xbar)) + sig = _stable_sigmoid(xbar) return sig + xbar * sig * (1 - sig) else: raise ValueError(f"Unsupported function type: {functype}") -def grad_grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: +def grad_grad(xbar: Any, y: Any, functype: int) -> Any: """Second derivative of the activation function.""" + xp = array_api_compat.array_namespace(xbar, y) if functype == 0: - return np.zeros_like(xbar) + return xp.zeros_like(xbar) elif functype == 1: return -2 * y * (1 - y * y) elif functype == 2: - var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var1 = xp.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1) return ( 3 * GGELU * SQRT_2_PI * xbar**2 * (1 - var1**2) @@ -103,26 +116,24 @@ def grad_grad(xbar: np.ndarray, y: np.ndarray, functype: int) -> np.ndarray: + var2 ) elif functype in [3, 4]: - return np.zeros_like(xbar) + return xp.zeros_like(xbar) elif functype == 5: sig = _stable_sigmoid(xbar) return sig * (1 - sig) elif functype == 6: return y * (1 - y) * (1 - 2 * y) elif functype == 7: - sig = 1.0 / (1.0 + np.exp(-xbar)) + sig = _stable_sigmoid(xbar) d_sig = sig * (1 - sig) return 2 * d_sig + xbar * d_sig * (1 - 2 * sig) else: raise ValueError(f"Unsupported function type: {functype}") -# ---- Chain-rule derivative propagation (numpy) ---- +# ---- Chain-rule derivative propagation (Array API compatible) ---- -def unaggregated_dy_dx_s( - y: np.ndarray, w: np.ndarray, xbar: np.ndarray, functype: int -) -> np.ndarray: +def unaggregated_dy_dx_s(y: Any, w: Any, xbar: Any, functype: int) -> Any: """First derivative for the first layer (scalar input).""" if y.ndim != 2: raise ValueError("Dim of input y should be 2") @@ -132,18 +143,17 @@ def unaggregated_dy_dx_s( raise ValueError("Dim of input xbar should be 2") grad_xbar_y = grad(xbar, y, functype) - w_flat = np.ravel(w)[: y.shape[1]] - w_rep = np.tile(w_flat, (y.shape[0], 1)) + w_rep = _repeat_flattened_weight_prefix(w, y.shape[0], y.shape[1]) return grad_xbar_y * w_rep def unaggregated_dy2_dx_s( - y: np.ndarray, - dy: np.ndarray, - w: np.ndarray, - xbar: np.ndarray, + y: Any, + dy: Any, + w: Any, + xbar: Any, functype: int, -) -> np.ndarray: +) -> Any: """Second derivative for the first layer (scalar input).""" if y.ndim != 2: raise ValueError("Dim of input y should be 2") @@ -155,19 +165,19 @@ def unaggregated_dy2_dx_s( raise ValueError("Dim of input xbar should be 2") gg = grad_grad(xbar, y, functype) - w_flat = np.ravel(w)[: y.shape[1]] - w_rep = np.tile(w_flat, (y.shape[0], 1)) + w_rep = _repeat_flattened_weight_prefix(w, y.shape[0], y.shape[1]) return gg * w_rep * w_rep def unaggregated_dy_dx( - z: np.ndarray, - w: np.ndarray, - dy_dx: np.ndarray, - ybar: np.ndarray, + z: Any, + w: Any, + dy_dx: Any, + ybar: Any, functype: int, -) -> np.ndarray: +) -> Any: """First derivative for subsequent layers.""" + xp = array_api_compat.array_namespace(z, w, dy_dx, ybar) if z.ndim != 2: raise ValueError("z must have 2 dimensions") if w.ndim != 2: @@ -181,28 +191,30 @@ def unaggregated_dy_dx( size = w.shape[0] grad_ybar_z = grad(ybar, z, functype) - dy_dx = np.ravel(dy_dx)[: length * size].reshape(length, size) - accumulator = dy_dx @ w + dy_dx = xp.reshape(dy_dx, (-1,))[: length * size] + dy_dx = xp.reshape(dy_dx, (length, size)) + accumulator = xp.matmul(dy_dx, w) dz_drou = grad_ybar_z * accumulator if width == size: dz_drou += dy_dx if width == 2 * size: - dy_dx = np.concatenate((dy_dx, dy_dx), axis=1) + dy_dx = xp.concat((dy_dx, dy_dx), axis=1) dz_drou += dy_dx return dz_drou def unaggregated_dy2_dx( - z: np.ndarray, - w: np.ndarray, - dy_dx: np.ndarray, - dy2_dx: np.ndarray, - ybar: np.ndarray, + z: Any, + w: Any, + dy_dx: Any, + dy2_dx: Any, + ybar: Any, functype: int, -) -> np.ndarray: +) -> Any: """Second derivative for subsequent layers.""" + xp = array_api_compat.array_namespace(z, w, dy_dx, dy2_dx, ybar) if z.ndim != 2: raise ValueError("z must have 2 dimensions") if w.ndim != 2: @@ -220,68 +232,62 @@ def unaggregated_dy2_dx( grad_ybar_z = grad(ybar, z, functype) gg = grad_grad(ybar, z, functype) - dy2_dx = np.ravel(dy2_dx)[: length * size].reshape(length, size) - dy_dx = np.ravel(dy_dx)[: length * size].reshape(length, size) + dy2_dx = xp.reshape(dy2_dx, (-1,))[: length * size] + dy2_dx = xp.reshape(dy2_dx, (length, size)) + dy_dx = xp.reshape(dy_dx, (-1,))[: length * size] + dy_dx = xp.reshape(dy_dx, (length, size)) - acc1 = dy2_dx @ w - acc2 = dy_dx @ w + acc1 = xp.matmul(dy2_dx, w) + acc2 = xp.matmul(dy_dx, w) dz_drou = grad_ybar_z * acc1 + gg * acc2 * acc2 if width == size: dz_drou += dy2_dx if width == 2 * size: - dy2_dx = np.concatenate((dy2_dx, dy2_dx), axis=1) + dy2_dx = xp.concat((dy2_dx, dy2_dx), axis=1) dz_drou += dy2_dx return dz_drou -# ---- DPTabulate with numpy math ---- +# ---- DPTabulate with Array API math ---- class DPTabulate(BaseTabulate): - r"""Backend-agnostic tabulation using numpy. + r"""Backend-agnostic tabulation using Array API compatible math. Compress a model by tabulating the embedding-net. The table is composed - of fifth-order polynomial coefficients assembled from two sub-tables. - - Parameters - ---------- - descrpt - Descriptor of the original model. - neuron - Number of neurons in each hidden layer of the embedding net. - type_one_side - Try to build N_types tables. Otherwise, building N_types^2 tables. - exclude_types - Excluded type pairs with no interaction. - activation_fn_name - Name of the activation function (e.g. "tanh", "gelu", "relu"). + of fifth-order polynomial coefficients fitted to the embedding-net output + and its derivatives over intervals of the environment matrix. """ def __init__( self, descrpt: Any, neuron: list[int], - type_one_side: bool = False, + type_one_side: bool, exclude_types: list[list[int]] | None = None, - activation_fn_name: str = "tanh", + activation_fn: str = "tanh", + suffix: str = "", + *, + activation_fn_name: str | None = None, ) -> None: exclude_types = [] if exclude_types is None else exclude_types - super().__init__( - descrpt, - neuron, - type_one_side, - exclude_types, - ) - self._activation_fn = get_activation_fn(activation_fn_name) - activation_fn_name = activation_fn_name.lower() - if activation_fn_name not in ACTIVATION_TO_FUNCTYPE: - raise RuntimeError(f"Unknown activation function: {activation_fn_name}") - self.functype = ACTIVATION_TO_FUNCTYPE[activation_fn_name] + if activation_fn_name is not None: + activation_fn = activation_fn_name + super().__init__(descrpt, neuron, type_one_side, exclude_types) self.descrpt_type = self._get_descrpt_type() + self.neuron = neuron + self.type_one_side = type_one_side + self.exclude_types = exclude_types + self.suffix = suffix + self.activation_fn = activation_fn + self.functype = ACTIVATION_TO_FUNCTYPE.get(activation_fn, -1) + if self.functype == -1: + raise ValueError(f"Unsupported activation function: {activation_fn}") + self._activation_fn = get_activation_fn(activation_fn) supported_descrpt_type = ("Atten", "A", "T", "T_TEBD", "R") if self.descrpt_type in supported_descrpt_type: @@ -310,141 +316,146 @@ def __init__( self.data_type = self._get_data_type() self.last_layer_size = self._get_last_layer_size() + def _get_math_backend_sample(self) -> Any: + """Return a sample array choosing the execution backend for math ops.""" + return np.empty((), dtype=self.data_type) + + @cached_property + def _math_backend_sample(self) -> Any: + return self._get_math_backend_sample() + + @cached_property + def _math_backend_device(self) -> Any: + return array_api_compat.device(self._math_backend_sample) + + def _backend_asarray(self, value: Any) -> Any: + xp = array_api_compat.array_namespace(self._math_backend_sample) + return xp.asarray(value, device=self._math_backend_device) + + @cached_property + def _matrix_backend(self) -> dict[str, list[Any]]: + matrix = { + layer: [self._backend_asarray(value) for value in values] + for layer, values in self.matrix.items() + } + self.matrix = None + return matrix + + @cached_property + def _bias_backend(self) -> dict[str, list[Any]]: + bias = { + layer: [self._backend_asarray(value) for value in values] + for layer, values in self.bias.items() + } + self.bias = None + return bias + def _make_data(self, xx: np.ndarray, idx: int) -> Any: """Forward pass through embedding net with derivative computation.""" - xx = xx.reshape(-1, 1) + xp = array_api_compat.array_namespace(self._math_backend_sample) + xx = xp.reshape(self._backend_asarray(xx), (-1, 1)) for layer in range(self.layer_size): + matrix = self._matrix_backend["layer_" + str(layer + 1)][idx] + bias = self._bias_backend["layer_" + str(layer + 1)][idx] if layer == 0: - xbar = ( - np.matmul(xx, self.matrix["layer_" + str(layer + 1)][idx]) - + self.bias["layer_" + str(layer + 1)][idx] - ) + xbar = xp.matmul(xx, matrix) + bias if self.neuron[0] == 1: - yy = ( - self._layer_0( - xx, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - + xx - ) + yy = self._layer_0(xx, matrix, bias) + xx dy = unaggregated_dy_dx_s( yy - xx, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, xbar, self.functype, - ) + np.ones((1, 1), dtype=yy.dtype) + ) + xp.ones( + (1, 1), dtype=yy.dtype, device=array_api_compat.device(yy) + ) dy2 = unaggregated_dy2_dx_s( yy - xx, dy, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, xbar, self.functype, ) elif self.neuron[0] == 2: - tt, yy = self._layer_1( - xx, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) + tt, yy = self._layer_1(xx, matrix, bias) dy = unaggregated_dy_dx_s( yy - tt, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, xbar, self.functype, - ) + np.ones((1, 2), dtype=yy.dtype) + ) + xp.ones( + (1, 2), dtype=yy.dtype, device=array_api_compat.device(yy) + ) dy2 = unaggregated_dy2_dx_s( yy - tt, dy, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, xbar, self.functype, ) else: - yy = self._layer_0( - xx, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) + yy = self._layer_0(xx, matrix, bias) dy = unaggregated_dy_dx_s( yy, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, xbar, self.functype, ) dy2 = unaggregated_dy2_dx_s( yy, dy, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, xbar, self.functype, ) else: - ybar = ( - np.matmul(yy, self.matrix["layer_" + str(layer + 1)][idx]) - + self.bias["layer_" + str(layer + 1)][idx] - ) + ybar = xp.matmul(yy, matrix) + bias if self.neuron[layer] == self.neuron[layer - 1]: - zz = ( - self._layer_0( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) - + yy - ) + zz = self._layer_0(yy, matrix, bias) + yy dz = unaggregated_dy_dx( zz - yy, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, dy, ybar, self.functype, ) dy2 = unaggregated_dy2_dx( zz - yy, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, dy, dy2, ybar, self.functype, ) elif self.neuron[layer] == 2 * self.neuron[layer - 1]: - tt, zz = self._layer_1( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) + tt, zz = self._layer_1(yy, matrix, bias) dz = unaggregated_dy_dx( zz - tt, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, dy, ybar, self.functype, ) dy2 = unaggregated_dy2_dx( zz - tt, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, dy, dy2, ybar, self.functype, ) else: - zz = self._layer_0( - yy, - self.matrix["layer_" + str(layer + 1)][idx], - self.bias["layer_" + str(layer + 1)][idx], - ) + zz = self._layer_0(yy, matrix, bias) dz = unaggregated_dy_dx( zz, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, dy, ybar, self.functype, ) dy2 = unaggregated_dy2_dx( zz, - self.matrix["layer_" + str(layer + 1)][idx], + matrix, dy, dy2, ybar, @@ -453,19 +464,19 @@ def _make_data(self, xx: np.ndarray, idx: int) -> Any: dy = dz yy = zz - vv = yy.astype(self.data_type) - dd = dy.astype(self.data_type) - d2 = dy2.astype(self.data_type) + vv = to_numpy_array(yy).astype(self.data_type) + dd = to_numpy_array(dy).astype(self.data_type) + d2 = to_numpy_array(dy2).astype(self.data_type) return vv, dd, d2 - def _layer_0(self, x: np.ndarray, w: np.ndarray, b: np.ndarray) -> np.ndarray: - return self._activation_fn(np.matmul(x, w) + b) + def _layer_0(self, x: Any, w: Any, b: Any) -> Any: + xp = array_api_compat.array_namespace(x, w, b) + return self._activation_fn(xp.matmul(x, w) + b) - def _layer_1( - self, x: np.ndarray, w: np.ndarray, b: np.ndarray - ) -> tuple[np.ndarray, np.ndarray]: - t = np.concatenate([x, x], axis=1) - return t, self._activation_fn(np.matmul(x, w) + b) + t + def _layer_1(self, x: Any, w: Any, b: Any) -> tuple[Any, Any]: + xp = array_api_compat.array_namespace(x, w, b) + t = xp.concat([x, x], axis=1) + return t, self._activation_fn(xp.matmul(x, w) + b) + t def _get_descrpt_type(self) -> str: """Determine descriptor type from serialized data.""" diff --git a/source/tests/common/test_tabulate_math_array_api.py b/source/tests/common/test_tabulate_math_array_api.py new file mode 100644 index 0000000000..59c1802c56 --- /dev/null +++ b/source/tests/common/test_tabulate_math_array_api.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import sys +import unittest + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.utils import tabulate_math as tm + +from ..consistent.common import ( + INSTALLED_ARRAY_API_STRICT, +) + +if INSTALLED_ARRAY_API_STRICT: + from ..array_api_strict.common import ( + to_array_api_strict_array, + ) + + +class TestTabulateMathArrayAPI(unittest.TestCase): + def setUp(self) -> None: + self.xbar_np = np.array([[0.1, -0.2], [0.3, 0.4]], dtype=np.float64) + self.y_np = np.tanh(self.xbar_np) + self.w_np = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float64) + + @unittest.skipUnless( + INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed" + ) + @unittest.skipUnless( + sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8" + ) + def test_chain_rule_helpers_array_api_strict_consistent_with_numpy(self) -> None: + xbar = to_array_api_strict_array(self.xbar_np) + y = to_array_api_strict_array(self.y_np) + w = to_array_api_strict_array(self.w_np) + + dy_s = tm.unaggregated_dy_dx_s(y, w, xbar, 1) + dy2_s = tm.unaggregated_dy2_dx_s(y, dy_s, w, xbar, 1) + dy = tm.unaggregated_dy_dx(y, w, dy_s, xbar, 1) + dy2 = tm.unaggregated_dy2_dx(y, w, dy_s, dy2_s, xbar, 1) + + dy_s_ref = tm.unaggregated_dy_dx_s(self.y_np, self.w_np, self.xbar_np, 1) + dy2_s_ref = tm.unaggregated_dy2_dx_s( + self.y_np, + dy_s_ref, + self.w_np, + self.xbar_np, + 1, + ) + dy_ref = tm.unaggregated_dy_dx( + self.y_np, + self.w_np, + dy_s_ref, + self.xbar_np, + 1, + ) + dy2_ref = tm.unaggregated_dy2_dx( + self.y_np, + self.w_np, + dy_s_ref, + dy2_s_ref, + self.xbar_np, + 1, + ) + + np.testing.assert_allclose(to_numpy_array(dy_s), dy_s_ref, atol=1e-10) + np.testing.assert_allclose(to_numpy_array(dy2_s), dy2_s_ref, atol=1e-10) + np.testing.assert_allclose(to_numpy_array(dy), dy_ref, atol=1e-10) + np.testing.assert_allclose(to_numpy_array(dy2), dy2_ref, atol=1e-10) + + @unittest.skipUnless( + INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed" + ) + @unittest.skipUnless( + sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8" + ) + def test_stable_sigmoid_and_silu_grad_array_api_strict_consistent_with_numpy( + self, + ) -> None: + xbar_np = np.array([[-1000.0, -1.0, 0.0, 1.0, 1000.0]], dtype=np.float64) + xbar = to_array_api_strict_array(xbar_np) + + stable = tm._stable_sigmoid(xbar) + silu_grad = tm.grad(xbar, stable, 7) + + stable_ref = tm._stable_sigmoid(xbar_np) + silu_grad_ref = tm.grad(xbar_np, stable_ref, 7) + + np.testing.assert_allclose(to_numpy_array(stable), stable_ref, atol=1e-10) + np.testing.assert_allclose( + to_numpy_array(silu_grad), + silu_grad_ref, + atol=1e-10, + )