Commit cd67bbe

Authored by Han Wang (wanghan-iapcm) and Jinzhe Zeng (njzjz)
refact(dpmodel,pt_expt): embedding net (#5205)
# EmbeddingNet Refactoring: Factory Function to Concrete Class

## Summary

This refactoring converts `EmbeddingNet` from a factory-generated dynamic class into a concrete class in the dpmodel backend. The change lets the auto-detection registry mechanism in pt_expt work seamlessly with `EmbeddingNet` attributes. This PR should be considered after #5194 and #5204.

## Motivation

**Before**: `EmbeddingNet` was created by the factory function `make_embedding_network(NativeNet, NativeLayer)`, producing a dynamically typed class `make_embedding_network.<locals>.EN`. This caused two problems:

1. **Cannot be registered**: dynamic classes cannot be imported or registered at module import time in the pt_expt registry.
2. **Name-based hacks required**: pt_expt wrappers had to check explicitly for `name == "embedding_net"` in `__setattr__` instead of using the type-based auto-detection mechanism.

**After**: `EmbeddingNet` is now a concrete class that can be registered in the pt_expt auto-conversion registry, eliminating the need for name-based special cases.

## Changes
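The registration problem with factory-generated classes can be seen in a few lines of plain Python. This is a toy illustration, not deepmd code; `make_net`, `Base`, and `ConcreteNet` are hypothetical names:

```python
# Why a factory-generated class is hard to register by type: the class is
# defined inside a function body, so it cannot be imported at module level,
# and every factory call produces a distinct type object.

def make_net(base):
    class EN(base):
        pass
    return EN

class Base:
    pass

NetA = make_net(Base)
NetB = make_net(Base)

# The generated class advertises itself as a function-local:
assert "<locals>" in NetA.__qualname__
# Each call yields a distinct type, so a registry keyed on the class
# object cannot be populated once at import time:
assert NetA is not NetB

# A concrete module-level class has a single stable identity:
class ConcreteNet(Base):
    pass

registry = {ConcreteNet: "converter"}
assert registry[type(ConcreteNet())] == "converter"
```

This is exactly the property the refactoring restores: a single importable class object that a type-keyed registry can hold.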
### 1. dpmodel: Concrete `EmbeddingNet` class

**File**: `deepmd/dpmodel/utils/network.py`

- Replaced the factory-generated class with a concrete `EmbeddingNet(NativeNet)` class
- Moved the constructor logic from the factory into `__init__`
- Fixed `deserialize` to use `type(obj.layers[0])` instead of hardcoding `super(EmbeddingNet, obj)`, allowing the pt_expt subclass to preserve its converted torch layers
- Kept the `make_embedding_network` factory for the pt/pd backends, which use a different base class (MLP)

```python
class EmbeddingNet(NativeNet):
    """The embedding network."""

    def __init__(
        self,
        in_dim,
        neuron=[24, 48, 96],
        activation_function="tanh",
        resnet_dt=False,
        precision=DEFAULT_PRECISION,
        seed=None,
        bias=True,
        trainable=True,
    ):
        layers = []
        i_in = in_dim
        if isinstance(trainable, bool):
            trainable = [trainable] * len(neuron)
        for idx, ii in enumerate(neuron):
            i_ot = ii
            layers.append(
                NativeLayer(
                    i_in,
                    i_ot,
                    bias=bias,
                    use_timestep=resnet_dt,
                    activation_function=activation_function,
                    resnet=True,
                    precision=precision,
                    seed=child_seed(seed, idx),
                    trainable=trainable[idx],
                ).serialize()
            )
            i_in = i_ot
        super().__init__(layers)
        self.in_dim = in_dim
        self.neuron = neuron
        self.activation_function = activation_function
        self.resnet_dt = resnet_dt
        self.precision = precision
        self.bias = bias

    @classmethod
    def deserialize(cls, data):
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 2, 1)
        data.pop("@class", None)
        layers = data.pop("layers")
        obj = cls(**data)
        # Use type(obj.layers[0]) to respect subclass layer types
        layer_type = type(obj.layers[0])
        obj.layers = type(obj.layers)(
            [layer_type.deserialize(layer) for layer in layers]
        )
        return obj
```
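The `type(obj.layers[0])` change is what lets a subclass keep its own layer type through a round trip. A minimal sketch of the mechanism with toy classes (`Layer`, `TorchLayer`, `Net`, `SubNet` are hypothetical stand-ins, not the deepmd API):

```python
class Layer:
    def __init__(self):
        self.w = 1.0

    def serialize(self):
        return {"w": self.w}

    @classmethod
    def deserialize(cls, data):
        obj = cls()
        obj.w = data["w"]
        return obj


class TorchLayer(Layer):
    """Stand-in for a backend-specific layer subclass."""


class Net:
    layer_cls = Layer

    def __init__(self):
        self.layers = [self.layer_cls()]

    def serialize(self):
        return {"layers": [layer.serialize() for layer in self.layers]}

    @classmethod
    def deserialize(cls, data):
        obj = cls()
        # The fix: dispatch on the layer type __init__ actually created,
        # instead of hardcoding the base Layer class.
        layer_type = type(obj.layers[0])
        obj.layers = [layer_type.deserialize(d) for d in data["layers"]]
        return obj


class SubNet(Net):
    layer_cls = TorchLayer  # subclass swaps in its own layer type


restored = SubNet.deserialize(SubNet().serialize())
# The subclass's layer type survives the round trip.
assert type(restored.layers[0]) is TorchLayer
```

Hardcoding `Layer.deserialize` here would have produced plain `Layer` objects and silently dropped the subclass behavior, which is the bug the real change avoids for pt_expt's torch layers.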
### 2. pt_expt: Wrapper and registration

**File**: `deepmd/pt_expt/utils/network.py`

- Created an `EmbeddingNet(EmbeddingNetDP, torch.nn.Module)` wrapper
- Converts dpmodel layers to pt_expt `NativeLayer` (torch modules) in `__init__`
- Registered in the auto-conversion registry

```python
class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        EmbeddingNetDP.__init__(self, *args, **kwargs)
        # Convert dpmodel layers to pt_expt NativeLayer
        self.layers = torch.nn.ModuleList(
            [NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)


register_dpmodel_mapping(
    EmbeddingNetDP,
    lambda v: EmbeddingNet.deserialize(v.serialize()),
)
```

### 3. TypeEmbedNet: Simplified to use the registry

**File**: `deepmd/pt_expt/utils/type_embed.py`

- No longer needs the name-based `embedding_net` check in `__setattr__`
- Uses the common `dpmodel_setattr`, which auto-converts via the registry
- Imports the `network` module so that the `EmbeddingNet` registration happens first

```python
class TypeEmbedNet(TypeEmbedNetDP, torch.nn.Module):
    def __setattr__(self, name: str, value: Any) -> None:
        # Auto-converts embedding_net via registry
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)
```

## Tests

### dpmodel tests

**File**: `source/tests/common/dpmodel/test_network.py`

Added to the `TestEmbeddingNet` class:

1. **`test_is_concrete_class`**: verifies `EmbeddingNet` is now a concrete class, not factory output
2. **`test_forward_pass`**: tests that the dpmodel forward pass produces correct shapes
3. **`test_trainable_parameter_variants`**: tests different trainable configurations (all trainable, all frozen, mixed)

(The existing `test_embedding_net` test already covers the serialization/deserialization round trip.)

### pt_expt integration tests

**File**: `source/tests/pt_expt/utils/test_network.py`

Created the `TestEmbeddingNetRefactor` suite with 8 tests:

1. **`test_pt_expt_embedding_net_wraps_dpmodel`**: verifies the pt_expt wrapper inherits correctly and converts layers
2. **`test_pt_expt_embedding_net_forward`**: tests that the pt_expt forward pass returns a `torch.Tensor`
3. **`test_serialization_round_trip_pt_expt`**: tests pt_expt serialize/deserialize
4. **`test_deserialize_preserves_layer_type`**: tests the key fix: `deserialize` uses `type(obj.layers[0])` to preserve pt_expt's torch layers
5. **`test_cross_backend_consistency`**: tests numerical consistency between dpmodel and pt_expt
6. **`test_registry_converts_dpmodel_to_pt_expt`**: tests that `try_convert_module` auto-converts dpmodel to pt_expt
7. **`test_auto_conversion_in_setattr`**: tests that `dpmodel_setattr` auto-converts `EmbeddingNet` attributes
8. **`test_trainable_parameter_handling`**: tests that trainable vs. frozen parameters work correctly in pt_expt

## Verification

All tests pass:

```bash
# dpmodel EmbeddingNet tests
python -m pytest source/tests/common/dpmodel/test_network.py::TestEmbeddingNet -v
# 4 passed in 0.41s

# pt_expt EmbeddingNet integration tests
python -m pytest source/tests/pt_expt/utils/test_network.py::TestEmbeddingNetRefactor -v
# 8 passed in 0.41s

# All pt_expt network tests
python -m pytest source/tests/pt_expt/utils/test_network.py -v
# 10 passed in 0.41s

# Descriptor tests (verify the refactoring doesn't break existing code)
python -m pytest source/tests/pt_expt/descriptor/test_se_e2_a.py -v -k consistency
# 1 passed
python -m pytest source/tests/universal/pt_expt/descriptor/test_descriptor.py -v
# 8 passed in 3.27s
```

## Benefits
1. **Type-based auto-detection**: no more name-based special cases in `__setattr__`
2. **Maintainability**: a single source of truth for `EmbeddingNet` in dpmodel
3. **Consistency**: same pattern as other dpmodel classes (`AtomExcludeMask`, `NetworkCollection`, etc.)
4. **Future-proof**: new attributes in dpmodel automatically work in pt_expt via the registry

## Backward Compatibility

- Serialization format unchanged (`@version` 2)
- All existing tests pass
- The `make_embedding_network` factory is kept for the pt/pd backends
- No changes to the public API

## Files Changed

### Modified

- `deepmd/dpmodel/utils/network.py`: concrete `EmbeddingNet` class + `deserialize` fix
- `deepmd/pt_expt/utils/network.py`: `EmbeddingNet` wrapper + registration
- `deepmd/pt_expt/utils/type_embed.py`: simplified to use the registry
- `source/tests/common/dpmodel/test_network.py`: added dpmodel `EmbeddingNet` tests (3 new tests)
- `source/tests/pt_expt/utils/test_network.py`: added pt_expt integration tests (8 new tests)

### No changes required

- All descriptor wrappers (se_e2_a, se_r, se_t, se_t_tebd) automatically work via the registry
- No changes to dpmodel logic or array_api_compat code

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
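The type-keyed registry the PR relies on can be sketched in a few lines. The function bodies below are illustrative only; `register_dpmodel_mapping` and `try_convert_module` mimic the names used in the PR but are not the deepmd sources, and the two classes are stand-ins:

```python
# Minimal sketch of a type-keyed auto-conversion registry.
_DPMODEL_MAPPING = {}

def register_dpmodel_mapping(dp_cls, converter):
    """Register a converter for a concrete dpmodel class."""
    _DPMODEL_MAPPING[dp_cls] = converter

def try_convert_module(value):
    """Convert value if its exact type is registered, else return it as-is."""
    converter = _DPMODEL_MAPPING.get(type(value))
    return converter(value) if converter is not None else value

class EmbeddingNetDP:  # stand-in for the dpmodel class
    pass

class EmbeddingNet(EmbeddingNetDP):  # stand-in for the pt_expt wrapper
    @classmethod
    def from_dp(cls, dp):
        return cls()

register_dpmodel_mapping(EmbeddingNetDP, EmbeddingNet.from_dp)

# A dpmodel instance gets converted; the wrapper type passes through
# untouched, because registration is keyed on the exact dpmodel type.
assert isinstance(try_convert_module(EmbeddingNetDP()), EmbeddingNet)
assert type(try_convert_module(EmbeddingNet())) is EmbeddingNet
```

Keying on the exact type (rather than `isinstance`) is what makes a concrete, importable `EmbeddingNet` class a prerequisite: a factory-local class has no stable type object to register.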
1 parent d310532 commit cd67bbe

6 files changed

Lines changed: 503 additions & 6 deletions


deepmd/dpmodel/utils/network.py

Lines changed: 109 additions & 1 deletion
```diff
@@ -785,7 +785,115 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
         return EN
 
 
-EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
+class EmbeddingNet(NativeNet):
+    """The embedding network.
+
+    Parameters
+    ----------
+    in_dim
+        Input dimension.
+    neuron
+        The number of neurons in each layer. The output dimension
+        is the same as the dimension of the last layer.
+    activation_function
+        The activation function.
+    resnet_dt
+        Use time step at the resnet architecture.
+    precision
+        Floating point precision for the model parameters.
+    seed : int, optional
+        Random seed.
+    bias : bool, optional
+        Whether to use bias in the embedding layer.
+    trainable : bool or list[bool], optional
+        Whether the weights are trainable. If a list, each element
+        corresponds to a layer.
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        neuron: list[int] = [24, 48, 96],
+        activation_function: str = "tanh",
+        resnet_dt: bool = False,
+        precision: str = DEFAULT_PRECISION,
+        seed: int | list[int] | None = None,
+        bias: bool = True,
+        trainable: bool | list[bool] = True,
+    ) -> None:
+        layers = []
+        i_in = in_dim
+        if isinstance(trainable, bool):
+            trainable = [trainable] * len(neuron)
+        for idx, ii in enumerate(neuron):
+            i_ot = ii
+            layers.append(
+                NativeLayer(
+                    i_in,
+                    i_ot,
+                    bias=bias,
+                    use_timestep=resnet_dt,
+                    activation_function=activation_function,
+                    resnet=True,
+                    precision=precision,
+                    seed=child_seed(seed, idx),
+                    trainable=trainable[idx],
+                ).serialize()
+            )
+            i_in = i_ot
+        super().__init__(layers)
+        self.in_dim = in_dim
+        self.neuron = neuron
+        self.activation_function = activation_function
+        self.resnet_dt = resnet_dt
+        self.precision = precision
+        self.bias = bias
+
+    def serialize(self) -> dict:
+        """Serialize the network to a dict.
+
+        Returns
+        -------
+        dict
+            The serialized network.
+        """
+        return {
+            "@class": "EmbeddingNetwork",
+            "@version": 2,
+            "in_dim": self.in_dim,
+            "neuron": self.neuron.copy(),
+            "activation_function": self.activation_function,
+            "resnet_dt": self.resnet_dt,
+            "bias": self.bias,
+            # make deterministic
+            "precision": np.dtype(PRECISION_DICT[self.precision]).name,
+            "layers": [layer.serialize() for layer in self.layers],
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "EmbeddingNet":
+        """Deserialize the network from a dict.
+
+        Parameters
+        ----------
+        data : dict
+            The dict to deserialize from.
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version", 1), 2, 1)
+        data.pop("@class", None)
+        layers = data.pop("layers")
+        obj = cls(**data)
+        # Reinitialize layers from serialized data, using the same layer type
+        # that __init__ created (respects subclass overrides via MRO).
+        if obj.layers:
+            layer_type = type(obj.layers[0])
+            obj.layers = type(obj.layers)(
+                [layer_type.deserialize(layer) for layer in layers]
+            )
+        else:
+            obj.layers = type(obj.layers)([])
+        return obj
 
 
 def make_fitting_network(
```
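In the constructor above, `child_seed(seed, idx)` gives each layer its own reproducible seed derived from the parent seed. A hypothetical stand-in illustrating the idea (this is not the deepmd implementation of `child_seed`):

```python
import hashlib

def child_seed(seed, idx):
    """Hypothetical stand-in: derive a deterministic per-child seed.

    Returns None when no seed was given (fully random init); otherwise
    a stable value, so layer idx always receives the same seed.
    """
    if seed is None:
        return None
    digest = hashlib.sha256(repr((seed, idx)).encode()).hexdigest()
    return int(digest, 16) % (2**32)

# Same (seed, idx) always yields the same child seed...
assert child_seed(7, 0) == child_seed(7, 0)
# ...while different layers get different, independent seeds.
assert child_seed(7, 0) != child_seed(7, 1)
assert child_seed(None, 3) is None
```

Deriving per-layer seeds this way keeps weight initialization reproducible without forcing every layer to share one RNG stream.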

deepmd/pt/utils/env.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,7 @@
     # only linux
     ncpus = len(os.sched_getaffinity(0))
 except AttributeError:
-    ncpus = os.cpu_count()
+    ncpus = os.cpu_count() or 1
 NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
 if multiprocessing.get_start_method() != "fork":
     # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader
```

deepmd/pt_expt/utils/env.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -34,7 +34,7 @@
     # only linux
     ncpus = len(os.sched_getaffinity(0))
 except AttributeError:
-    ncpus = os.cpu_count()
+    ncpus = os.cpu_count() or 1
 NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
 if multiprocessing.get_start_method() != "fork":
     # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader
```
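The `or 1` guard matters because `os.cpu_count()` is documented to return `None` when the CPU count cannot be determined, and `min(4, None)` raises a `TypeError`. A small sketch of the corrected logic (`num_workers` is a hypothetical helper mirroring the env.py expression, without the environment-variable lookup):

```python
def num_workers(cpu_count):
    """Mirror the env.py fallback: cap workers at 4, assume 1 core if unknown."""
    ncpus = cpu_count or 1  # os.cpu_count() may return None
    return min(4, ncpus)

assert num_workers(16) == 4   # capped at 4 workers
assert num_workers(2) == 2
assert num_workers(None) == 1  # without "or 1", min(4, None) raises TypeError
```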

deepmd/pt_expt/utils/network.py

Lines changed: 22 additions & 3 deletions
```diff
@@ -10,11 +10,11 @@
 from deepmd.dpmodel.common import (
     NativeOP,
 )
+from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP
 from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP
 from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
 from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
 from deepmd.dpmodel.utils.network import (
-    make_embedding_network,
     make_fitting_network,
     make_multilayer_network,
 )
@@ -91,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.call(x)
 
 
-class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
-    pass
+class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        torch.nn.Module.__init__(self)
+        EmbeddingNetDP.__init__(self, *args, **kwargs)
+        # EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances.
+        # Convert to pt_expt NativeLayer and wrap in ModuleList.
+        self.layers = torch.nn.ModuleList(
+            [NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
+        )
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return torch.nn.Module.__call__(self, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.call(x)
+
+
+register_dpmodel_mapping(
+    EmbeddingNetDP,
+    lambda v: EmbeddingNet.deserialize(v.serialize()),
+)
 
 
 class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
```
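The `NativeLayer.deserialize(layer.serialize())` call in the wrapper uses the shared serialization format as a bridge between backends: any two classes that agree on the dict schema can be converted through it without importing each other. A toy sketch of the pattern (`NumpyLayer` and `TorchLikeLayer` are hypothetical names, not deepmd classes):

```python
class NumpyLayer:
    """Stand-in for a dpmodel layer holding plain arrays."""

    def __init__(self, w=0.5):
        self.w = w

    def serialize(self):
        return {"w": self.w}

    @classmethod
    def deserialize(cls, data):
        return cls(data["w"])


class TorchLikeLayer(NumpyLayer):
    """Stand-in for the pt_expt layer: same dict schema, different runtime type."""


# Round-tripping through the shared dict format converts the backend type
# while preserving the parameters.
dp_layer = NumpyLayer(w=1.5)
pt_layer = TorchLikeLayer.deserialize(dp_layer.serialize())
assert type(pt_layer) is TorchLikeLayer
assert pt_layer.w == 1.5
```

The design choice keeps the conversion logic in one place (the serialization schema) instead of scattering per-field copy code across every wrapper.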

source/tests/common/dpmodel/test_network.py

Lines changed: 108 additions & 0 deletions
```diff
@@ -180,6 +180,114 @@ def test_embedding_net(self) -> None:
         inp = np.ones([ni], dtype=get_xp_precision(np, prec))
         np.testing.assert_allclose(en0.call(inp), en1.call(inp))
 
+    def test_is_concrete_class(self) -> None:
+        """Verify EmbeddingNet is a concrete class, not factory-generated."""
+        in_dim = 4
+        neuron = [8, 16, 32]
+        net = EmbeddingNet(
+            in_dim=in_dim,
+            neuron=neuron,
+            activation_function="tanh",
+            resnet_dt=True,
+            precision="float64",
+        )
+        # Check it's the actual EmbeddingNet class, not a dynamic class
+        self.assertEqual(net.__class__.__name__, "EmbeddingNet")
+        self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network")
+        # Verify it has the expected attributes
+        self.assertEqual(net.in_dim, in_dim)
+        self.assertEqual(net.neuron, neuron)
+        self.assertEqual(net.activation_function, "tanh")
+        self.assertEqual(net.resnet_dt, True)
+        self.assertEqual(len(net.layers), len(neuron))
+
+    def test_forward_pass(self) -> None:
+        """Test EmbeddingNet forward pass produces correct shapes."""
+        in_dim = 4
+        neuron = [8, 16, 32]
+        net = EmbeddingNet(
+            in_dim=in_dim,
+            neuron=neuron,
+            activation_function="tanh",
+            resnet_dt=True,
+            precision="float64",
+        )
+        rng = np.random.default_rng()
+        x = rng.standard_normal((5, in_dim))
+        out = net.call(x)
+        self.assertEqual(out.shape, (5, neuron[-1]))
+        self.assertEqual(out.dtype, np.float64)
+
+    def test_trainable_parameter_variants(self) -> None:
+        """Test EmbeddingNet with different trainable configurations."""
+        in_dim = 4
+        neuron = [8, 16]
+
+        # All trainable
+        net_trainable = EmbeddingNet(
+            in_dim=in_dim,
+            neuron=neuron,
+            trainable=True,
+        )
+        for layer in net_trainable.layers:
+            self.assertTrue(layer.trainable)
+
+        # All frozen
+        net_frozen = EmbeddingNet(
+            in_dim=in_dim,
+            neuron=neuron,
+            trainable=False,
+        )
+        for layer in net_frozen.layers:
+            self.assertFalse(layer.trainable)
+
+        # Mixed trainable
+        net_mixed = EmbeddingNet(
+            in_dim=in_dim,
+            neuron=neuron,
+            trainable=[True, False],
+        )
+        self.assertTrue(net_mixed.layers[0].trainable)
+        self.assertFalse(net_mixed.layers[1].trainable)
+
+    def test_empty_layers_round_trip(self) -> None:
+        """Test EmbeddingNet with empty neuron list (edge case for deserialize).
+
+        This tests the fix for IndexError when neuron=[] results in empty layers.
+        The deserialize method should handle this case without trying to access
+        layers[0] when the list is empty.
+        """
+        in_dim = 4
+        neuron = []  # Empty neuron list
+
+        # Create network with empty layers
+        net = EmbeddingNet(
+            in_dim=in_dim,
+            neuron=neuron,
+            activation_function="tanh",
+            resnet_dt=True,
+            precision="float64",
+        )
+
+        # Verify it has no layers
+        self.assertEqual(len(net.layers), 0)
+
+        # Serialize and deserialize
+        serialized = net.serialize()
+        net_restored = EmbeddingNet.deserialize(serialized)
+
+        # Verify restored network also has no layers
+        self.assertEqual(len(net_restored.layers), 0)
+        self.assertEqual(net_restored.in_dim, in_dim)
+        self.assertEqual(net_restored.neuron, neuron)
+
+        # Verify forward pass works (should return input unchanged)
+        rng = np.random.default_rng()
+        x = rng.standard_normal((5, in_dim))
+        out = net_restored.call(x)
+        # With no layers, output should equal input
+        np.testing.assert_allclose(out, x)
+
 
 class TestFittingNet(unittest.TestCase):
     def test_fitting_net(self) -> None:
```
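The empty-neuron edge case tested above holds because the forward pass is a fold over the layer list; with no layers, the input passes through unchanged. A toy version of that invariant (not the deepmd `call` implementation):

```python
def forward(layers, x):
    """Apply each layer in order; an empty network is the identity."""
    for layer in layers:
        x = layer(x)
    return x

double = lambda v: 2 * v
assert forward([double, double], 3) == 12
assert forward([], 3) == 3  # no layers: input returned unchanged
```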
