Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ec2e031
implement pytorch-exportable for se_e2_a descriptor
Feb 5, 2026
b8a48ff
better type for xp.zeros
Feb 5, 2026
1cc001f
implement env, base_descriptor and exclude_mask, remove the dependenc…
Feb 6, 2026
f2fbe88
mv to_torch_tensor to common
Feb 6, 2026
e2afbe9
simplify __init__ of the NaiveLayer
Feb 6, 2026
4ba511a
fix bug
Feb 6, 2026
fb9598a
fix bug
Feb 6, 2026
fa03351
simplify init method of se_e2_a descriptor. fix bug in consistent UT
Feb 6, 2026
09b33f1
restructure the test folders. add test_common.
Feb 6, 2026
67f2e54
add test_exclusion_mask.py
Feb 6, 2026
f7d83dd
fix potential import issue in test.
Feb 6, 2026
0c96bb6
correct __call__(). fix bug
Feb 6, 2026
9dca912
fix registration issue
Feb 6, 2026
17f0a5d
fix pt-expt file extension
Feb 6, 2026
8ce93ba
fix(pt): expansion of get_default_nthreads()
Feb 6, 2026
3091988
fix bug of intra-inter
Feb 6, 2026
85f0583
fix bug of default dp inter value
Feb 6, 2026
d33324d
fix cicd
Feb 6, 2026
4de9a56
feat: add support for se_r
Feb 6, 2026
f4dc0af
fix device of xp array
Feb 6, 2026
2384835
fix device of xp array
Feb 6, 2026
9646d71
revert extend_coord_with_ghosts
Feb 6, 2026
f270069
raise error for non-implemented methods
Feb 6, 2026
57433d3
restore import torch
Feb 6, 2026
eedcbaf
fix(pt,pt-expt): guard thread setters
Feb 6, 2026
d8b2cf4
make exclusion mask modules
Feb 6, 2026
aeef15a
fix(pt-expt): clear params on None
Feb 6, 2026
8bdb1f8
fix bug
Feb 7, 2026
d3b01da
utility to handle dpmodel -> pt_expt conversion
Feb 8, 2026
3452a2a
fix to_numpy_array device
Feb 8, 2026
ba8e7ab
chore(dpmodel,pt_expt): refactorize the implementation of embedding net
Feb 8, 2026
621c7cc
feat: se_t and se_t_tebd descriptors for the pytorch exportable backend.
Feb 8, 2026
faa4026
fix bug
Feb 8, 2026
8c63762
fix bug
Feb 8, 2026
ae58734
merge with master
Feb 8, 2026
222cd6a
Revert "feat: se_t and se_t_tebd descriptors for the pytorch exportab…
Feb 8, 2026
dcd7df4
Merge branch 'master' into refact-embed-net
njzjz Feb 9, 2026
249627a
fix case neuron=[]
Feb 10, 2026
6ea965b
fix issue of ncpu may be None
Feb 10, 2026
073d12c
fix bug of device
Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,115 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
return EN


EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
class EmbeddingNet(NativeNet):
    """The embedding network.

    Parameters
    ----------
    in_dim
        Input dimension.
    neuron
        The number of neurons in each layer. The output dimension
        is the same as the dimension of the last layer.
        Defaults to ``[24, 48, 96]`` when not given.
    activation_function
        The activation function.
    resnet_dt
        Use time step at the resnet architecture.
    precision
        Floating point precision for the model parameters.
    seed : int, optional
        Random seed.
    bias : bool, Optional
        Whether to use bias in the embedding layer.
    trainable : bool or list[bool], Optional
        Whether the weights are trainable. If a list, each element
        corresponds to a layer.
    """

    def __init__(
        self,
        in_dim: int,
        neuron: list[int] | None = None,
        activation_function: str = "tanh",
        resnet_dt: bool = False,
        precision: str = DEFAULT_PRECISION,
        seed: int | list[int] | None = None,
        bias: bool = True,
        trainable: bool | list[bool] = True,
    ) -> None:
        # Use a None sentinel instead of a mutable default list: a
        # ``neuron=[24, 48, 96]`` default would be one shared list object
        # across all calls. Copy the caller's list so later mutation of
        # the argument cannot alter this network's recorded topology.
        neuron = [24, 48, 96] if neuron is None else list(neuron)
        # A single bool is broadcast to a per-layer flag list.
        if isinstance(trainable, bool):
            trainable = [trainable] * len(neuron)
        layers = []
        i_in = in_dim
        for idx, ii in enumerate(neuron):
            i_ot = ii
            layers.append(
                NativeLayer(
                    i_in,
                    i_ot,
                    bias=bias,
                    use_timestep=resnet_dt,
                    activation_function=activation_function,
                    resnet=True,
                    precision=precision,
                    seed=child_seed(seed, idx),
                    trainable=trainable[idx],
                ).serialize()
            )
            i_in = i_ot
        super().__init__(layers)
        self.in_dim = in_dim
        self.neuron = neuron
        self.activation_function = activation_function
        self.resnet_dt = resnet_dt
        self.precision = precision
        self.bias = bias

    def serialize(self) -> dict:
        """Serialize the network to a dict.

        Returns
        -------
        dict
            The serialized network.
        """
        return {
            "@class": "EmbeddingNetwork",
            "@version": 2,
            "in_dim": self.in_dim,
            "neuron": self.neuron.copy(),
            "activation_function": self.activation_function,
            "resnet_dt": self.resnet_dt,
            "bias": self.bias,
            # make deterministic
            "precision": np.dtype(PRECISION_DICT[self.precision]).name,
            "layers": [layer.serialize() for layer in self.layers],
        }

    @classmethod
    def deserialize(cls, data: dict) -> "EmbeddingNet":
        """Deserialize the network from a dict.

        Parameters
        ----------
        data : dict
            The dict to deserialize from.
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 2, 1)
        data.pop("@class", None)
        layers = data.pop("layers")
        obj = cls(**data)
        # Reinitialize layers from serialized data, using the same layer type
        # that __init__ created (respects subclass overrides via MRO).
        if obj.layers:
            layer_type = type(obj.layers[0])
            obj.layers = type(obj.layers)(
                [layer_type.deserialize(layer) for layer in layers]
            )
        else:
            # neuron=[] yields no layers; keep an empty container of the
            # same type instead of indexing layers[0].
            obj.layers = type(obj.layers)([])
        return obj
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def make_fitting_network(
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# only linux
ncpus = len(os.sched_getaffinity(0))
except AttributeError:
ncpus = os.cpu_count()
ncpus = os.cpu_count() or 1
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
if multiprocessing.get_start_method() != "fork":
# spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt_expt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# only linux
ncpus = len(os.sched_getaffinity(0))
except AttributeError:
ncpus = os.cpu_count()
ncpus = os.cpu_count() or 1
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
if multiprocessing.get_start_method() != "fork":
# spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader
Expand Down
25 changes: 22 additions & 3 deletions deepmd/pt_expt/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP
from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP
from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
from deepmd.dpmodel.utils.network import (
make_embedding_network,
make_fitting_network,
make_multilayer_network,
)
Expand Down Expand Up @@ -91,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.call(x)


class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
pass
class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
    """PyTorch-exportable embedding network.

    Wraps the dpmodel ``EmbeddingNetDP`` implementation in a
    ``torch.nn.Module`` so its layers are held in a ``ModuleList``
    and calls route through torch's ``__call__`` machinery.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # torch.nn.Module.__init__ runs first so the attribute
        # assignments below go through Module's __setattr__ (which is
        # what registers submodules) — do not reorder these two calls.
        torch.nn.Module.__init__(self)
        EmbeddingNetDP.__init__(self, *args, **kwargs)
        # EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances.
        # Convert to pt_expt NativeLayer and wrap in ModuleList.
        self.layers = torch.nn.ModuleList(
            [NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Dispatch through torch so hooks fire and forward() is used,
        # overriding NativeOP-style __call__ from the DP base class.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate the actual computation to the inherited DP ``call``.
        return self.call(x)


# Register a converter so a dpmodel EmbeddingNet can be turned into the
# pt_expt variant via a serialize/deserialize round trip.
register_dpmodel_mapping(
    EmbeddingNetDP,
    lambda v: EmbeddingNet.deserialize(v.serialize()),
)


class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
Expand Down
108 changes: 108 additions & 0 deletions source/tests/common/dpmodel/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,114 @@ def test_embedding_net(self) -> None:
inp = np.ones([ni], dtype=get_xp_precision(np, prec))
np.testing.assert_allclose(en0.call(inp), en1.call(inp))

def test_is_concrete_class(self) -> None:
    """Verify EmbeddingNet is a concrete class, not factory-generated."""
    n_in = 4
    widths = [8, 16, 32]
    net = EmbeddingNet(
        in_dim=n_in,
        neuron=widths,
        activation_function="tanh",
        resnet_dt=True,
        precision="float64",
    )
    # The class must be the statically defined one, not a class
    # synthesized at runtime by a make_* factory.
    cls = net.__class__
    self.assertEqual("EmbeddingNet", cls.__name__)
    self.assertEqual("deepmd.dpmodel.utils.network", cls.__module__)
    # Constructor arguments are recorded on the instance.
    self.assertEqual(n_in, net.in_dim)
    self.assertEqual(widths, net.neuron)
    self.assertEqual("tanh", net.activation_function)
    self.assertEqual(True, net.resnet_dt)
    self.assertEqual(len(widths), len(net.layers))

def test_forward_pass(self) -> None:
    """Test EmbeddingNet forward pass produces correct shapes."""
    n_in, widths = 4, [8, 16, 32]
    net = EmbeddingNet(
        in_dim=n_in,
        neuron=widths,
        activation_function="tanh",
        resnet_dt=True,
        precision="float64",
    )
    batch = 5
    samples = np.random.default_rng().standard_normal((batch, n_in))
    result = net.call(samples)
    # Batch dimension is preserved; features widen to the last layer size.
    self.assertEqual((batch, widths[-1]), result.shape)
    self.assertEqual(np.float64, result.dtype)

def test_trainable_parameter_variants(self) -> None:
    """Test EmbeddingNet with different trainable configurations."""
    n_in, widths = 4, [8, 16]

    # A single bool is broadcast to every layer.
    for flag in (True, False):
        net = EmbeddingNet(
            in_dim=n_in,
            neuron=widths,
            trainable=flag,
        )
        for layer in net.layers:
            self.assertEqual(flag, layer.trainable)

    # A list assigns per-layer flags positionally.
    mixed = EmbeddingNet(
        in_dim=n_in,
        neuron=widths,
        trainable=[True, False],
    )
    self.assertTrue(mixed.layers[0].trainable)
    self.assertFalse(mixed.layers[1].trainable)

def test_empty_layers_round_trip(self) -> None:
    """Test EmbeddingNet with empty neuron list (edge case for deserialize).

    This tests the fix for IndexError when neuron=[] results in empty layers.
    The deserialize method should handle this case without trying to access
    layers[0] when the list is empty.
    """
    n_in = 4

    # A network built with no hidden widths has zero layers.
    original = EmbeddingNet(
        in_dim=n_in,
        neuron=[],
        activation_function="tanh",
        resnet_dt=True,
        precision="float64",
    )
    self.assertEqual(0, len(original.layers))

    # Round-trip through the dict representation must not raise and
    # must preserve the empty topology.
    restored = EmbeddingNet.deserialize(original.serialize())
    self.assertEqual(0, len(restored.layers))
    self.assertEqual(n_in, restored.in_dim)
    self.assertEqual([], restored.neuron)

    # With no layers, call() acts as the identity map.
    data = np.random.default_rng().standard_normal((5, n_in))
    np.testing.assert_allclose(restored.call(data), data)


class TestFittingNet(unittest.TestCase):
def test_fitting_net(self) -> None:
Expand Down
Loading