Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
605 changes: 23 additions & 582 deletions deepmd/pt/utils/tabulate.py

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions deepmd/pt_expt/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ def enable_compression(
check_frequency
The overflow check frequency
"""
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
Expand All @@ -71,7 +68,7 @@ def enable_compression(
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
data["activation_function"],
)
self.table_config = [
table_extrapolate,
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt_expt/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ def enable_compression(
check_frequency
The overflow check frequency
"""
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
Expand Down Expand Up @@ -105,7 +102,7 @@ def enable_compression(
repinit_data["neuron"],
repinit_data.get("type_one_side", False),
repinit_data.get("exclude_types", []),
ActivationFn(repinit_data["activation_function"]),
repinit_data["activation_function"],
)
self.table_config = [
table_extrapolate,
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt_expt/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def enable_compression(
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
Expand All @@ -49,7 +46,7 @@ def enable_compression(
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
data["activation_function"],
)
self.table_config = [
table_extrapolate,
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt_expt/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def enable_compression(
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
Expand All @@ -49,7 +46,7 @@ def enable_compression(
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
data["activation_function"],
)
self.table_config = [
table_extrapolate,
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt_expt/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ def enable_compression(
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
Expand All @@ -49,7 +46,7 @@ def enable_compression(
self,
data["neuron"],
exclude_types=data["exclude_types"],
activation_fn=ActivationFn(data["activation_function"]),
activation_fn_name=data["activation_function"],
)
# SE_T scales strides by 10
stride_1_scaled = table_stride_1 * 10
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt_expt/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def enable_compression(
check_frequency
The overflow check frequency
"""
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.pt_expt.utils.tabulate import (
DPTabulate,
)
Expand All @@ -65,7 +62,7 @@ def enable_compression(
self,
data["neuron"],
exclude_types=data["exclude_types"],
activation_fn=ActivationFn(data["activation_function"]),
activation_fn_name=data["activation_function"],
)
# SE_T scales strides by 10
stride_1_scaled = table_stride_1 * 10
Expand Down
52 changes: 29 additions & 23 deletions deepmd/pt_expt/utils/tabulate.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DPTabulate for the pt_expt backend.

Subclasses the pt backend's DPTabulate, overriding _get_descrpt_type() to
detect descriptor types via serialized data rather than isinstance checks
against pt-specific classes.
Inherits the numpy math from ``deepmd.utils.tabulate_math.DPTabulate``
and overrides ``_convert_numpy_to_tensor`` for torch tensor conversion
and ``_get_descrpt_type`` for serialization-based type detection.
No dependency on the pt backend.
"""

from typing import (
Any,
)

from deepmd.pt.utils.tabulate import DPTabulate as DPTabulatePT
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.utils.tabulate_math import DPTabulate as DPTabulateBase


class DPTabulate(DPTabulatePT):
class DPTabulate(DPTabulateBase):
"""Tabulation helper for pt_expt descriptors.

The descriptor passed to this class must serialize to a dict with
Expand All @@ -34,8 +32,8 @@ class DPTabulate(DPTabulatePT):
Whether to use one-side type embedding.
exclude_types
Excluded type pairs.
activation_fn
The activation function used in the embedding net.
activation_fn_name
Name of the activation function (e.g. "tanh", "gelu").
"""

def __init__(
Expand All @@ -44,23 +42,20 @@ def __init__(
neuron: list[int],
type_one_side: bool = False,
exclude_types: list[list[int]] = [],
activation_fn: ActivationFn = ActivationFn("tanh"),
activation_fn_name: str = "tanh",
) -> None:
# DPTabulatePT.__init__ works here because:
# 1. _get_descrpt_type is overridden to use serialized data (not isinstance)
# 2. The isinstance(descrpt, DescrptDPA2) check in parent just returns False
# for pt_expt descriptors — callers pass the repinit block directly.
super().__init__(descrpt, neuron, type_one_side, exclude_types, activation_fn)
super().__init__(
descrpt,
neuron,
type_one_side,
exclude_types,
activation_fn_name=activation_fn_name,
)

def _get_descrpt_type(self) -> str:
"""Determine descriptor type from serialized data.

Instead of isinstance checks against pt classes, use the "type" key
from the serialized descriptor dict.
"""
"""Determine descriptor type from serialized data."""
data = self.descrpt.serialize()
type_str = data.get("type", "")

type_map = {
"se_e2_a": "A",
"se_r": "R",
Expand All @@ -69,8 +64,19 @@ def _get_descrpt_type(self) -> str:
"dpa1": "Atten",
"se_atten_v2": "Atten",
}

descrpt_type = type_map.get(type_str)
if descrpt_type is None:
raise RuntimeError(f"Unsupported descriptor type: {type_str}")
return descrpt_type

def _convert_numpy_to_tensor(self) -> None:
"""Convert self.data from np.ndarray to torch.Tensor."""
import torch

from deepmd.pt_expt.utils.env import (
DEVICE,
)

self._convert_numpy_float_to_int()
for ii in self.data:
self.data[ii] = torch.tensor(self.data[ii], device=DEVICE) # pylint: disable=no-explicit-dtype
1 change: 0 additions & 1 deletion deepmd/tf/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def __init__(
neuron,
type_one_side,
exclude_types,
False,
)

self.descrpt_type = self._get_descrpt_type()
Expand Down
36 changes: 12 additions & 24 deletions deepmd/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
neuron: list[int],
type_one_side: bool,
exclude_types: list[list[int]],
is_pt: bool,
) -> None:
"""Constructor."""
super().__init__()
Expand All @@ -38,7 +37,6 @@ def __init__(
self.neuron = neuron
self.type_one_side = type_one_side
self.exclude_types = exclude_types
self.is_pt = is_pt

"""Need to be initialized in the subclass."""
self.descrpt_type = "Base"
Expand Down Expand Up @@ -91,6 +89,11 @@ def build(
"""
# tabulate range [lower, upper] with stride0 'stride0'
lower, upper = self._get_env_mat_range(min_nbor_dist)
# Normalize to per-type scalars: PT serialized data produces
# multi-dimensional arrays (ntypes, nnei) while TF produces 1D.
if lower.ndim > 1:
lower = np.min(lower, axis=tuple(range(1, lower.ndim)))
upper = np.max(upper, axis=tuple(range(1, upper.ndim)))
if self.descrpt_type in ("Atten", "AEbdV2"):
uu = np.max(upper)
ll = np.min(lower)
Expand Down Expand Up @@ -127,12 +130,8 @@ def build(
net = (
"filter_" + str(ielement) + "_net_" + str(ii % self.ntypes)
)
if self.is_pt:
uu = np.max(upper[ielement])
ll = np.min(lower[ielement])
else:
uu = upper[ielement]
ll = lower[ielement]
uu = upper[ielement]
ll = lower[ielement]
xx = np.arange(ll, uu, stride0, dtype=self.data_type)
xx = np.append(
xx,
Expand All @@ -150,13 +149,8 @@ def build(
elif self.descrpt_type == "T":
xx_all = []
for ii in range(self.ntypes):
"""Pt and tf is different here. Pt version is a two-dimensional array."""
if self.is_pt:
uu = np.max(upper[ii])
ll = np.min(lower[ii])
else:
ll = lower[ii]
uu = upper[ii]
ll = lower[ii]
uu = upper[ii]
xx = np.arange(extrapolate * ll, ll, stride1, dtype=self.data_type)
xx = np.append(xx, np.arange(ll, uu, stride0, dtype=self.data_type))
xx = np.append(
Expand All @@ -176,12 +170,8 @@ def build(
).astype(int)
idx = 0
for ii in range(self.ntypes):
if self.is_pt:
uu = np.max(upper[ii])
ll = np.min(lower[ii])
else:
ll = lower[ii]
uu = upper[ii]
ll = lower[ii]
uu = upper[ii]
for jj in range(ii, self.ntypes):
net = "filter_" + str(ii) + "_net_" + str(jj)
self._build_lower(
Expand All @@ -193,7 +183,7 @@ def build(
stride0,
stride1,
extrapolate,
nspline[ii][0] if self.is_pt else nspline[ii],
nspline[ii],
)
idx += 1
elif self.descrpt_type == "T_TEBD":
Expand Down Expand Up @@ -279,8 +269,6 @@ def build(
raise RuntimeError("Unsupported descriptor")

self._convert_numpy_to_tensor()
if self.is_pt:
self._convert_numpy_float_to_int()
return self.lower, self.upper

# generate_spline_table
Expand Down
Loading