Commit cd67bbe
refact(dpmodel,pt_expt): embedding net (#5205)
# EmbeddingNet Refactoring: Factory Function to Concrete Class
## Summary
This refactoring converts `EmbeddingNet` from a factory-generated
dynamic class to a concrete class in the dpmodel backend. This change
enables the auto-detection registry mechanism in pt_expt to work
seamlessly with EmbeddingNet attributes.
This PR is considered after #5194 and #5204
## Motivation
**Before**: `EmbeddingNet` was created by a factory function
`make_embedding_network(NativeNet, NativeLayer)`, producing a
dynamically-typed class `make_embedding_network.<locals>.EN`. This
caused two problems:
1. **Cannot be registered**: Dynamic classes can't be imported or
registered at module import time in the pt_expt registry
2. **Name-based hacks required**: pt_expt wrappers had to explicitly
check for `name == "embedding_net"` in `__setattr__` instead of using
the type-based auto-detection mechanism
**After**: `EmbeddingNet` is now a concrete class that can be registered
in the pt_expt auto-conversion registry, eliminating the need for
name-based special cases.
## Changes
### 1. dpmodel: Concrete `EmbeddingNet` class
**File**: `deepmd/dpmodel/utils/network.py`
- Replaced factory-generated class with concrete
`EmbeddingNet(NativeNet)` class
- Moved constructor logic from factory into `__init__`
- Fixed `deserialize` to use `type(obj.layers[0])` instead of hardcoding
`super(EmbeddingNet, obj)`, allowing pt_expt subclass to preserve its
converted torch layers
- Kept `make_embedding_network` factory for pt/pd backends that use
different base classes (MLP)
```python
class EmbeddingNet(NativeNet):
"""The embedding network."""
def __init__(self, in_dim, neuron=[24, 48, 96], activation_function="tanh",
resnet_dt=False, precision=DEFAULT_PRECISION, seed=None,
bias=True, trainable=True):
layers = []
i_in = in_dim
if isinstance(trainable, bool):
trainable = [trainable] * len(neuron)
for idx, ii in enumerate(neuron):
i_ot = ii
layers.append(
NativeLayer(
i_in, i_ot, bias=bias, use_timestep=resnet_dt,
activation_function=activation_function, resnet=True,
precision=precision, seed=child_seed(seed, idx),
trainable=trainable[idx]
).serialize()
)
i_in = i_ot
super().__init__(layers)
self.in_dim = in_dim
self.neuron = neuron
self.activation_function = activation_function
self.resnet_dt = resnet_dt
self.precision = precision
self.bias = bias
@classmethod
def deserialize(cls, data):
data = data.copy()
check_version_compatibility(data.pop("@Version", 1), 2, 1)
data.pop("@Class", None)
layers = data.pop("layers")
obj = cls(**data)
# Use type(obj.layers[0]) to respect subclass layer types
layer_type = type(obj.layers[0])
obj.layers = type(obj.layers)(
[layer_type.deserialize(layer) for layer in layers]
)
return obj
```
### 2. pt_expt: Wrapper and registration
**File**: `deepmd/pt_expt/utils/network.py`
- Created `EmbeddingNet(EmbeddingNetDP, torch.nn.Module)` wrapper
- Converts dpmodel layers to pt_expt `NativeLayer` (torch modules) in
`__init__`
- Registered in auto-conversion registry
```python
class EmbeddingNet(EmbeddingNetDP, torch.nn.Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
torch.nn.Module.__init__(self)
EmbeddingNetDP.__init__(self, *args, **kwargs)
# Convert dpmodel layers to pt_expt NativeLayer
self.layers = torch.nn.ModuleList(
[NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return torch.nn.Module.__call__(self, *args, **kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.call(x)
register_dpmodel_mapping(
EmbeddingNetDP,
lambda v: EmbeddingNet.deserialize(v.serialize()),
)
```
### 3. TypeEmbedNet: Simplified to use registry
**File**: `deepmd/pt_expt/utils/type_embed.py`
- No longer needs name-based `embedding_net` check in `__setattr__`
- Uses common `dpmodel_setattr` which auto-converts via registry
- Imports `network` module to ensure `EmbeddingNet` registration happens
first
```python
class TypeEmbedNet(TypeEmbedNetDP, torch.nn.Module):
def __setattr__(self, name: str, value: Any) -> None:
# Auto-converts embedding_net via registry
handled, value = dpmodel_setattr(self, name, value)
if not handled:
super().__setattr__(name, value)
```
## Tests
### dpmodel tests
**File**: `source/tests/common/dpmodel/test_network.py`
Added to `TestEmbeddingNet` class:
1. **`test_is_concrete_class`**: Verifies `EmbeddingNet` is now a
concrete class, not factory output
2. **`test_forward_pass`**: Tests dpmodel forward pass produces correct
shapes
3. **`test_trainable_parameter_variants`**: Tests different trainable
configurations (all trainable, all frozen, mixed)
(The existing `test_embedding_net` test already covers
serialization/deserialization round-trip)
### pt_expt integration tests
**File**: `source/tests/pt_expt/utils/test_network.py`
Created `TestEmbeddingNetRefactor` test suite with 8 tests:
1. **`test_pt_expt_embedding_net_wraps_dpmodel`**: Verifies pt_expt
wrapper inherits correctly and converts layers
2. **`test_pt_expt_embedding_net_forward`**: Tests pt_expt forward pass
returns torch.Tensor
3. **`test_serialization_round_trip_pt_expt`**: Tests pt_expt
serialize/deserialize
4. **`test_deserialize_preserves_layer_type`**: Tests the key fix -
`deserialize` uses `type(obj.layers[0])` to preserve pt_expt's torch
layers
5. **`test_cross_backend_consistency`**: Tests numerical consistency
between dpmodel and pt_expt
6. **`test_registry_converts_dpmodel_to_pt_expt`**: Tests
`try_convert_module` auto-converts dpmodel to pt_expt
7. **`test_auto_conversion_in_setattr`**: Tests `dpmodel_setattr`
auto-converts EmbeddingNet attributes
8. **`test_trainable_parameter_handling`**: Tests trainable vs frozen
parameters work correctly in pt_expt
## Verification
All tests pass:
```bash
# dpmodel EmbeddingNet tests
python -m pytest source/tests/common/dpmodel/test_network.py::TestEmbeddingNet -v
# 4 passed in 0.41s
# pt_expt EmbeddingNet integration tests
python -m pytest source/tests/pt_expt/utils/test_network.py::TestEmbeddingNetRefactor -v
# 8 passed in 0.41s
# All pt_expt network tests
python -m pytest source/tests/pt_expt/utils/test_network.py -v
# 10 passed in 0.41s
# Descriptor tests (verify refactoring doesn't break existing code)
python -m pytest source/tests/pt_expt/descriptor/test_se_e2_a.py -v -k consistency
# 1 passed
python -m pytest source/tests/universal/pt_expt/descriptor/test_descriptor.py -v
# 8 passed in 3.27s
```
## Benefits
1. **Type-based auto-detection**: No more name-based special cases in
`__setattr__`
2. **Maintainability**: Single source of truth for EmbeddingNet in
dpmodel
3. **Consistency**: Same pattern as other dpmodel classes
(AtomExcludeMask, NetworkCollection, etc.)
4. **Future-proof**: New attributes in dpmodel automatically work in
pt_expt via registry
## Backward Compatibility
- Serialization format unchanged (version 2.1)
- All existing tests pass
- `make_embedding_network` factory kept for pt/pd backends
- No changes to public API
## Files Changed
### Modified
- `deepmd/dpmodel/utils/network.py`: Concrete EmbeddingNet class +
deserialize fix
- `deepmd/pt_expt/utils/network.py`: EmbeddingNet wrapper + registration
- `deepmd/pt_expt/utils/type_embed.py`: Simplified to use registry
- `source/tests/common/dpmodel/test_network.py`: Added dpmodel
EmbeddingNet tests (3 new tests)
- `source/tests/pt_expt/utils/test_network.py`: Added pt_expt
integration tests (8 new tests)
### No changes required
- All descriptor wrappers (se_e2_a, se_r, se_t, se_t_tebd) automatically
work via registry
- No changes to dpmodel logic or array_api_compat code
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **New Features**
* Added PyTorch compatibility layer enabling DPModel neural network
components to be used with PyTorch workflows for training and inference
* Enhanced embedding network with explicit serialization and
deserialization capabilities
* **Refactor**
* Restructured embedding network with explicit class design for improved
type stability and control flow management
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>1 parent d310532 commit cd67bbe
6 files changed
Lines changed: 503 additions & 6 deletions
File tree
- deepmd
- dpmodel/utils
- pt_expt/utils
- pt/utils
- source/tests
- common/dpmodel
- pt_expt/utils
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
785 | 785 | | |
786 | 786 | | |
787 | 787 | | |
788 | | - | |
| 788 | + | |
| 789 | + | |
| 790 | + | |
| 791 | + | |
| 792 | + | |
| 793 | + | |
| 794 | + | |
| 795 | + | |
| 796 | + | |
| 797 | + | |
| 798 | + | |
| 799 | + | |
| 800 | + | |
| 801 | + | |
| 802 | + | |
| 803 | + | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
| 811 | + | |
| 812 | + | |
| 813 | + | |
| 814 | + | |
| 815 | + | |
| 816 | + | |
| 817 | + | |
| 818 | + | |
| 819 | + | |
| 820 | + | |
| 821 | + | |
| 822 | + | |
| 823 | + | |
| 824 | + | |
| 825 | + | |
| 826 | + | |
| 827 | + | |
| 828 | + | |
| 829 | + | |
| 830 | + | |
| 831 | + | |
| 832 | + | |
| 833 | + | |
| 834 | + | |
| 835 | + | |
| 836 | + | |
| 837 | + | |
| 838 | + | |
| 839 | + | |
| 840 | + | |
| 841 | + | |
| 842 | + | |
| 843 | + | |
| 844 | + | |
| 845 | + | |
| 846 | + | |
| 847 | + | |
| 848 | + | |
| 849 | + | |
| 850 | + | |
| 851 | + | |
| 852 | + | |
| 853 | + | |
| 854 | + | |
| 855 | + | |
| 856 | + | |
| 857 | + | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
| 862 | + | |
| 863 | + | |
| 864 | + | |
| 865 | + | |
| 866 | + | |
| 867 | + | |
| 868 | + | |
| 869 | + | |
| 870 | + | |
| 871 | + | |
| 872 | + | |
| 873 | + | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
| 892 | + | |
| 893 | + | |
| 894 | + | |
| 895 | + | |
| 896 | + | |
789 | 897 | | |
790 | 898 | | |
791 | 899 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
37 | | - | |
| 37 | + | |
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
37 | | - | |
| 37 | + | |
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| 13 | + | |
13 | 14 | | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
17 | | - | |
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| |||
91 | 91 | | |
92 | 92 | | |
93 | 93 | | |
94 | | - | |
95 | | - | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
96 | 115 | | |
97 | 116 | | |
98 | 117 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
180 | 180 | | |
181 | 181 | | |
182 | 182 | | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
183 | 291 | | |
184 | 292 | | |
185 | 293 | | |
| |||
0 commit comments