Skip to content

Commit ba8e7ab

Browse files
author
Han Wang
committed
chore(dpmodel,pt_expt): refactorize the implementation of embedding net
1 parent 3452a2a commit ba8e7ab

4 files changed

Lines changed: 457 additions & 4 deletions

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,112 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
788788
return EN
789789

790790

791-
EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
791+
class EmbeddingNet(NativeNet):
792+
"""The embedding network.
793+
794+
Parameters
795+
----------
796+
in_dim
797+
Input dimension.
798+
neuron
799+
The number of neurons in each layer. The output dimension
800+
is the same as the dimension of the last layer.
801+
activation_function
802+
The activation function.
803+
resnet_dt
804+
Use time step at the resnet architecture.
805+
precision
806+
Floating point precision for the model parameters.
807+
seed : int, optional
808+
Random seed.
809+
bias : bool, Optional
810+
Whether to use bias in the embedding layer.
811+
trainable : bool or list[bool], Optional
812+
Whether the weights are trainable. If a list, each element
813+
corresponds to a layer.
814+
"""
815+
816+
def __init__(
817+
self,
818+
in_dim: int,
819+
neuron: list[int] = [24, 48, 96],
820+
activation_function: str = "tanh",
821+
resnet_dt: bool = False,
822+
precision: str = DEFAULT_PRECISION,
823+
seed: int | list[int] | None = None,
824+
bias: bool = True,
825+
trainable: bool | list[bool] = True,
826+
) -> None:
827+
layers = []
828+
i_in = in_dim
829+
if isinstance(trainable, bool):
830+
trainable = [trainable] * len(neuron)
831+
for idx, ii in enumerate(neuron):
832+
i_ot = ii
833+
layers.append(
834+
NativeLayer(
835+
i_in,
836+
i_ot,
837+
bias=bias,
838+
use_timestep=resnet_dt,
839+
activation_function=activation_function,
840+
resnet=True,
841+
precision=precision,
842+
seed=child_seed(seed, idx),
843+
trainable=trainable[idx],
844+
).serialize()
845+
)
846+
i_in = i_ot
847+
super().__init__(layers)
848+
self.in_dim = in_dim
849+
self.neuron = neuron
850+
self.activation_function = activation_function
851+
self.resnet_dt = resnet_dt
852+
self.precision = precision
853+
self.bias = bias
854+
855+
def serialize(self) -> dict:
856+
"""Serialize the network to a dict.
857+
858+
Returns
859+
-------
860+
dict
861+
The serialized network.
862+
"""
863+
return {
864+
"@class": "EmbeddingNetwork",
865+
"@version": 2,
866+
"in_dim": self.in_dim,
867+
"neuron": self.neuron.copy(),
868+
"activation_function": self.activation_function,
869+
"resnet_dt": self.resnet_dt,
870+
"bias": self.bias,
871+
# make deterministic
872+
"precision": np.dtype(PRECISION_DICT[self.precision]).name,
873+
"layers": [layer.serialize() for layer in self.layers],
874+
}
875+
876+
@classmethod
877+
def deserialize(cls, data: dict) -> "EmbeddingNet":
878+
"""Deserialize the network from a dict.
879+
880+
Parameters
881+
----------
882+
data : dict
883+
The dict to deserialize from.
884+
"""
885+
data = data.copy()
886+
check_version_compatibility(data.pop("@version", 1), 2, 1)
887+
data.pop("@class", None)
888+
layers = data.pop("layers")
889+
obj = cls(**data)
890+
# Reinitialize layers from serialized data, using the same layer type
891+
# that __init__ created (respects subclass overrides via MRO).
892+
layer_type = type(obj.layers[0])
893+
obj.layers = type(obj.layers)(
894+
[layer_type.deserialize(layer) for layer in layers]
895+
)
896+
return obj
792897

793898

794899
def make_fitting_network(

deepmd/pt_expt/utils/network.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from deepmd.dpmodel.common import (
1111
NativeOP,
1212
)
13+
from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP
1314
from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP
1415
from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
1516
from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
1617
from deepmd.dpmodel.utils.network import (
17-
make_embedding_network,
1818
make_fitting_network,
1919
make_multilayer_network,
2020
)
@@ -91,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
9191
return self.call(x)
9292

9393

94-
class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
95-
pass
94+
class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
95+
def __init__(self, *args: Any, **kwargs: Any) -> None:
96+
torch.nn.Module.__init__(self)
97+
EmbeddingNetDP.__init__(self, *args, **kwargs)
98+
# EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances.
99+
# Convert to pt_expt NativeLayer and wrap in ModuleList.
100+
self.layers = torch.nn.ModuleList(
101+
[NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
102+
)
103+
104+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
105+
return torch.nn.Module.__call__(self, *args, **kwargs)
106+
107+
def forward(self, x: torch.Tensor) -> torch.Tensor:
108+
return self.call(x)
109+
110+
111+
register_dpmodel_mapping(
112+
EmbeddingNetDP,
113+
lambda v: EmbeddingNet.deserialize(v.serialize()),
114+
)
96115

97116

98117
class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):

source/tests/common/dpmodel/test_network.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,76 @@ def test_embedding_net(self) -> None:
180180
inp = np.ones([ni], dtype=get_xp_precision(np, prec))
181181
np.testing.assert_allclose(en0.call(inp), en1.call(inp))
182182

183+
def test_is_concrete_class(self) -> None:
184+
"""Verify EmbeddingNet is a concrete class, not factory-generated."""
185+
in_dim = 4
186+
neuron = [8, 16, 32]
187+
net = EmbeddingNet(
188+
in_dim=in_dim,
189+
neuron=neuron,
190+
activation_function="tanh",
191+
resnet_dt=True,
192+
precision="float64",
193+
)
194+
# Check it's the actual EmbeddingNet class, not a dynamic class
195+
self.assertEqual(net.__class__.__name__, "EmbeddingNet")
196+
self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network")
197+
# Verify it has the expected attributes
198+
self.assertEqual(net.in_dim, in_dim)
199+
self.assertEqual(net.neuron, neuron)
200+
self.assertEqual(net.activation_function, "tanh")
201+
self.assertEqual(net.resnet_dt, True)
202+
self.assertEqual(len(net.layers), len(neuron))
203+
204+
def test_forward_pass(self) -> None:
205+
"""Test EmbeddingNet forward pass produces correct shapes."""
206+
in_dim = 4
207+
neuron = [8, 16, 32]
208+
net = EmbeddingNet(
209+
in_dim=in_dim,
210+
neuron=neuron,
211+
activation_function="tanh",
212+
resnet_dt=True,
213+
precision="float64",
214+
)
215+
rng = np.random.default_rng()
216+
x = rng.standard_normal((5, in_dim))
217+
out = net.call(x)
218+
self.assertEqual(out.shape, (5, neuron[-1]))
219+
self.assertEqual(out.dtype, np.float64)
220+
221+
def test_trainable_parameter_variants(self) -> None:
222+
"""Test EmbeddingNet with different trainable configurations."""
223+
in_dim = 4
224+
neuron = [8, 16]
225+
226+
# All trainable
227+
net_trainable = EmbeddingNet(
228+
in_dim=in_dim,
229+
neuron=neuron,
230+
trainable=True,
231+
)
232+
for layer in net_trainable.layers:
233+
self.assertTrue(layer.trainable)
234+
235+
# All frozen
236+
net_frozen = EmbeddingNet(
237+
in_dim=in_dim,
238+
neuron=neuron,
239+
trainable=False,
240+
)
241+
for layer in net_frozen.layers:
242+
self.assertFalse(layer.trainable)
243+
244+
# Mixed trainable
245+
net_mixed = EmbeddingNet(
246+
in_dim=in_dim,
247+
neuron=neuron,
248+
trainable=[True, False],
249+
)
250+
self.assertTrue(net_mixed.layers[0].trainable)
251+
self.assertFalse(net_mixed.layers[1].trainable)
252+
183253

184254
class TestFittingNet(unittest.TestCase):
185255
def test_fitting_net(self) -> None:

0 commit comments

Comments
 (0)