Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions deepmd/pt/utils/tabulate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""PyTorch-specific DPTabulate wrapper.

Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate``
and adds torch-specific ``_convert_numpy_to_tensor`` and
``_get_descrpt_type`` (isinstance checks against PT descriptor classes).
Inherits the shared Array API tabulation math from
``deepmd.utils.tabulate_math.DPTabulate`` and adds torch-specific
backend selection, tensor conversion, and descriptor type detection.
"""

from typing import (
Expand Down Expand Up @@ -31,8 +31,8 @@
class DPTabulate(DPTabulateBase):
r"""PyTorch tabulation wrapper.

Accepts a PT ``ActivationFn`` module and delegates all math to the
numpy base class. Only overrides tensor conversion and descriptor
Accepts a PT ``ActivationFn`` module and delegates all shared math to the
base class. Overrides backend selection, tensor conversion, and descriptor
type detection.

Parameters
Expand Down Expand Up @@ -91,6 +91,10 @@ def _get_descrpt_type(self) -> str:
return "T_TEBD"
raise RuntimeError(f"Unsupported descriptor {self.descrpt}")

def _get_math_backend_sample(self) -> Any:
"""Run shared tabulation math on the current torch device."""
return torch.empty((), dtype=torch.float64, device=env.DEVICE)

def _convert_numpy_to_tensor(self) -> None:
"""Convert self.data from np.ndarray to torch.Tensor."""
self._convert_numpy_float_to_int()
Expand Down
16 changes: 13 additions & 3 deletions deepmd/pt_expt/utils/tabulate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DPTabulate for the pt_expt backend.

Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate``
and overrides ``_convert_numpy_to_tensor`` for torch tensor conversion
and ``_get_descrpt_type`` for serialization-based type detection.
Inherits the shared Array API tabulation math from
``deepmd.utils.tabulate_math.DPTabulate`` and overrides backend
selection, tensor conversion, and descriptor type detection.
No dependency on the pt backend.
"""

Expand Down Expand Up @@ -70,6 +70,16 @@ def _get_descrpt_type(self) -> str:
raise RuntimeError(f"Unsupported descriptor type: {type_str}")
return descrpt_type

def _get_math_backend_sample(self) -> Any:
"""Run shared tabulation math on the pt_expt torch device."""
import torch

from deepmd.pt_expt.utils.env import (
DEVICE,
)

return torch.empty((), dtype=torch.float64, device=DEVICE)

def _convert_numpy_to_tensor(self) -> None:
"""Convert self.data from np.ndarray to torch.Tensor."""
import torch
Expand Down
Loading
Loading