Skip to content

Commit efc27cf

Browse files
authored
refactor: use Array API-compatible tabulate math (#5366)
## Summary - port `deepmd/utils/tabulate_math.py` helpers from NumPy-only operators to Array API-compatible operators - let the PT / pt_expt tabulation wrappers choose a torch sample backend so the shared math path can execute on torch devices - add a regression test that verifies the helper functions use the provided Array API namespace rather than silently falling back to NumPy ## Testing - `. .venv/bin/activate && PYTHONPATH=. pytest -q source/tests/common/test_tabulate_math_array_api.py` - `python3 -m py_compile deepmd/utils/tabulate_math.py deepmd/pt/utils/tabulate.py deepmd/pt_expt/utils/tabulate.py source/tests/common/test_tabulate_math_array_api.py` Closes #5352 Authored by OpenClaw (model: gpt-5.4) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Device-aware sampling so tabulation math runs on the active compute device. * **Refactor** * Tabulation math migrated to Array API–compatible implementation with unified backend selection, device-aware tensor handling, and updated derivative propagation. * **Tests** * Added tests validating gradient/chain-rule computations and numerical stability versus NumPy under the Array API backend. * **Documentation** * Module docs updated to reflect Array API compatibility. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent b02fd91 commit efc27cf

File tree

4 files changed

+277
-156
lines changed

4 files changed

+277
-156
lines changed

deepmd/pt/utils/tabulate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""PyTorch-specific DPTabulate wrapper.
33
4-
Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate``
5-
and adds torch-specific ``_convert_numpy_to_tensor`` and
6-
``_get_descrpt_type`` (isinstance checks against PT descriptor classes).
4+
Inherits the shared Array API tabulation math from
5+
``deepmd.utils.tabulate_math.DPTabulate`` and adds torch-specific
6+
backend selection, tensor conversion, and descriptor type detection.
77
"""
88

99
from typing import (
@@ -31,8 +31,8 @@
3131
class DPTabulate(DPTabulateBase):
3232
r"""PyTorch tabulation wrapper.
3333
34-
Accepts a PT ``ActivationFn`` module and delegates all math to the
35-
numpy base class. Only overrides tensor conversion and descriptor
34+
Accepts a PT ``ActivationFn`` module and delegates all shared math to the
35+
base class. Overrides backend selection, tensor conversion, and descriptor
3636
type detection.
3737
3838
Parameters
@@ -91,6 +91,10 @@ def _get_descrpt_type(self) -> str:
9191
return "T_TEBD"
9292
raise RuntimeError(f"Unsupported descriptor {self.descrpt}")
9393

94+
def _get_math_backend_sample(self) -> Any:
95+
"""Run shared tabulation math on the current torch device."""
96+
return torch.empty((), dtype=torch.float64, device=env.DEVICE)
97+
9498
def _convert_numpy_to_tensor(self) -> None:
9599
"""Convert self.data from np.ndarray to torch.Tensor."""
96100
self._convert_numpy_float_to_int()

deepmd/pt_expt/utils/tabulate.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
"""DPTabulate for the pt_expt backend.
33
4-
Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate``
5-
and overrides ``_convert_numpy_to_tensor`` for torch tensor conversion
6-
and ``_get_descrpt_type`` for serialization-based type detection.
4+
Inherits the shared Array API tabulation math from
5+
``deepmd.utils.tabulate_math.DPTabulate`` and overrides backend
6+
selection, tensor conversion, and descriptor type detection.
77
No dependency on the pt backend.
88
"""
99

@@ -70,6 +70,16 @@ def _get_descrpt_type(self) -> str:
7070
raise RuntimeError(f"Unsupported descriptor type: {type_str}")
7171
return descrpt_type
7272

73+
def _get_math_backend_sample(self) -> Any:
74+
"""Run shared tabulation math on the pt_expt torch device."""
75+
import torch
76+
77+
from deepmd.pt_expt.utils.env import (
78+
DEVICE,
79+
)
80+
81+
return torch.empty((), dtype=torch.float64, device=DEVICE)
82+
7383
def _convert_numpy_to_tensor(self) -> None:
7484
"""Convert self.data from np.ndarray to torch.Tensor."""
7585
import torch

0 commit comments

Comments
 (0)