Commit 2d7fdc5
refact(dpmodel,pt_expt): fitting net (#5207)
# FittingNet Refactoring: Factory Function to Concrete Class
## Summary
This refactoring converts `FittingNet` from a factory-generated dynamic
class to a concrete class in the dpmodel backend, following the same
pattern as the EmbeddingNet refactoring. This enables the auto-detection
registry mechanism in pt_expt to work seamlessly with FittingNet.
This PR is considered after
#5194 and
#5204
## Motivation
**Before**: `FittingNet` was created by a factory function
`make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)`, producing
a dynamically-typed class. This caused:
1. **Cannot be registered**: Dynamic classes can't be imported or
registered at module import time in the pt_expt registry
2. **Type matching fails**: Each call to `make_fitting_network` creates
a new class type, so registry lookup by type fails
**After**: `FittingNet` is now a concrete class that can be registered
in the pt_expt auto-conversion registry.
## Changes
### 1. dpmodel: Concrete `FittingNet` class
**File**: `deepmd/dpmodel/utils/network.py`
- Created concrete `FittingNet(EmbeddingNet)` class
- Moved constructor logic from factory into `__init__`
- Fixed `deserialize` to use `type(obj.layers[0])` instead of hardcoding
`T_Network.__init__(obj, layers)`, allowing pt_expt subclass to preserve
its converted torch layers
- Kept `make_fitting_network` factory for backwards compatibility (for
pt/pd backends)
```python
class FittingNet(EmbeddingNet):
"""The fitting network."""
def __init__(self, in_dim, out_dim, neuron=[24, 48, 96],
activation_function="tanh", resnet_dt=False,
precision=DEFAULT_PRECISION, bias_out=True,
seed=None, trainable=True):
# Handle trainable parameter
if trainable is None:
trainable = [True] * (len(neuron) + 1)
elif isinstance(trainable, bool):
trainable = [trainable] * (len(neuron) + 1)
# Initialize embedding layers via parent
super().__init__(
in_dim, neuron=neuron,
activation_function=activation_function,
resnet_dt=resnet_dt, precision=precision,
seed=seed, trainable=trainable[:-1]
)
# Add output layer
i_in = neuron[-1] if len(neuron) > 0 else in_dim
self.layers.append(
NativeLayer(
i_in, out_dim, bias=bias_out,
use_timestep=False, activation_function=None,
resnet=False, precision=precision,
seed=child_seed(seed, len(neuron)),
trainable=trainable[-1]
)
)
self.out_dim = out_dim
self.bias_out = bias_out
@classmethod
def deserialize(cls, data):
data = data.copy()
check_version_compatibility(data.pop("@Version", 1), 1, 1)
data.pop("@Class", None)
layers = data.pop("layers")
obj = cls(**data)
# Use type(obj.layers[0]) to respect subclass layer types
layer_type = type(obj.layers[0])
obj.layers = type(obj.layers)(
[layer_type.deserialize(layer) for layer in layers]
)
return obj
```
### 2. pt_expt: Wrapper and registration
**File**: `deepmd/pt_expt/utils/network.py`
- Added import: `from deepmd.dpmodel.utils.network import FittingNet as
FittingNetDP`
- Created `FittingNet(FittingNetDP, torch.nn.Module)` wrapper
- Converts dpmodel layers to pt_expt `NativeLayer` (torch modules) in
`__init__`
- Registered in auto-conversion registry
```python
from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP
class FittingNet(FittingNetDP, torch.nn.Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
torch.nn.Module.__init__(self)
FittingNetDP.__init__(self, *args, **kwargs)
# Convert dpmodel layers to pt_expt NativeLayer
self.layers = torch.nn.ModuleList(
[NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return torch.nn.Module.__call__(self, *args, **kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.call(x)
register_dpmodel_mapping(
FittingNetDP,
lambda v: FittingNet.deserialize(v.serialize()),
)
```
## Tests
### dpmodel tests
**File**: `source/tests/common/dpmodel/test_network.py`
Added to `TestFittingNet` class:
1. **`test_fitting_net`**: Original roundtrip serialization test
(already existed)
2. **`test_is_concrete_class`**: Verifies `FittingNet` is now a concrete
class, not factory output
3. **`test_forward_pass`**: Tests dpmodel forward pass produces correct
output shapes (single and batch)
4. **`test_trainable_parameter_variants`**: Tests different trainable
configurations (all trainable, all frozen, mixed)
### pt_expt integration tests
**File**: `source/tests/pt_expt/utils/test_network.py`
Created `TestFittingNetRefactor` test suite with 4 tests:
1. **`test_pt_expt_fitting_net_wraps_dpmodel`**: Verifies pt_expt
wrapper inherits correctly and converts layers
2. **`test_pt_expt_fitting_net_forward`**: Tests pt_expt forward pass
returns torch.Tensor with correct shape
3. **`test_serialization_round_trip_pt_expt`**: Tests pt_expt
serialize/deserialize round-trip
4. **`test_registry_converts_dpmodel_to_pt_expt`**: Tests
`try_convert_module` auto-converts dpmodel to pt_expt
## Verification
All tests pass:
```bash
# dpmodel network tests (includes new FittingNet tests)
python -m pytest source/tests/common/dpmodel/test_network.py -v
# 19 passed in 0.56s (was 16, added 3 FittingNet tests)
# dpmodel FittingNet tests specifically
python -m pytest source/tests/common/dpmodel/test_network.py::TestFittingNet -v
# 4 passed in 0.44s
# pt_expt network tests (EmbeddingNet + FittingNet)
python -m pytest source/tests/pt_expt/utils/test_network.py -v
# 14 passed in 0.45s
# Descriptor tests (verify refactoring doesn't break existing code)
python -m pytest source/tests/pt_expt/descriptor/ -v
# 8 passed in 5.43s
```
## Benefits
1. **Type-based auto-detection**: FittingNet now works with the registry
mechanism
2. **Consistency**: Same pattern as EmbeddingNet and other dpmodel
classes
3. **Maintainability**: Single source of truth for FittingNet in dpmodel
4. **Future-proof**: Any dpmodel FittingNet instances can be
auto-converted to pt_expt
## Backward Compatibility
- Serialization format unchanged (version 1)
- All existing tests pass
- `make_fitting_network` factory kept for pt/pd backends
- No changes to public API
## Files Changed
### Modified
- `deepmd/dpmodel/utils/network.py`: Concrete FittingNet class +
deserialize fix
- `deepmd/pt_expt/utils/network.py`: FittingNet wrapper + registration
- `source/tests/common/dpmodel/test_network.py`: Added dpmodel
FittingNet tests (3 new tests)
- `source/tests/pt_expt/utils/test_network.py`: Added pt_expt
integration tests (4 new tests)
### Pattern
This refactoring follows the exact same pattern as
`EMBEDDING_NET_REFACTOR.md`:
1. Convert factory-generated class to concrete class in dpmodel
2. Fix `deserialize` to use `type(obj.layers[0])`
3. Create pt_expt wrapper with layer conversion in `__init__`
4. Register with `register_dpmodel_mapping`
5. Add comprehensive tests
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
## Release Notes
* **New Features**
* Added PyTorch experimental descriptor implementations for SeT and
SeTTebd with full export/tracing support
* Introduced PyTorch-compatible wrapper classes for network components
enabling seamless integration with PyTorch workflows
* **Improvements**
* Enhanced device-aware tensor operations across all descriptors for
better multi-device support
* Improved error handling with explicit error messages when statistics
are missing instead of silent failures
* Refactored FittingNet as a concrete class with explicit public
interface
* **Tests**
* Added comprehensive test coverage for new PyTorch experimental
descriptors and network wrappers
* Added unit tests validating serialization, deserialization, and
forward pass behavior
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>1 parent 02bd1fc commit 2d7fdc5
4 files changed
Lines changed: 352 additions & 4 deletions
File tree
- deepmd
- dpmodel/utils
- pt_expt/utils
- source/tests
- common/dpmodel
- pt_expt/utils
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1003 | 1003 | | |
1004 | 1004 | | |
1005 | 1005 | | |
1006 | | - | |
| 1006 | + | |
| 1007 | + | |
| 1008 | + | |
| 1009 | + | |
| 1010 | + | |
| 1011 | + | |
| 1012 | + | |
| 1013 | + | |
| 1014 | + | |
| 1015 | + | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
| 1032 | + | |
| 1033 | + | |
| 1034 | + | |
| 1035 | + | |
| 1036 | + | |
| 1037 | + | |
| 1038 | + | |
| 1039 | + | |
| 1040 | + | |
| 1041 | + | |
| 1042 | + | |
| 1043 | + | |
| 1044 | + | |
| 1045 | + | |
| 1046 | + | |
| 1047 | + | |
| 1048 | + | |
| 1049 | + | |
| 1050 | + | |
| 1051 | + | |
| 1052 | + | |
| 1053 | + | |
| 1054 | + | |
| 1055 | + | |
| 1056 | + | |
| 1057 | + | |
| 1058 | + | |
| 1059 | + | |
| 1060 | + | |
| 1061 | + | |
| 1062 | + | |
| 1063 | + | |
| 1064 | + | |
| 1065 | + | |
| 1066 | + | |
| 1067 | + | |
| 1068 | + | |
| 1069 | + | |
| 1070 | + | |
| 1071 | + | |
| 1072 | + | |
| 1073 | + | |
| 1074 | + | |
| 1075 | + | |
| 1076 | + | |
| 1077 | + | |
| 1078 | + | |
| 1079 | + | |
| 1080 | + | |
| 1081 | + | |
| 1082 | + | |
| 1083 | + | |
| 1084 | + | |
| 1085 | + | |
| 1086 | + | |
| 1087 | + | |
| 1088 | + | |
| 1089 | + | |
| 1090 | + | |
| 1091 | + | |
| 1092 | + | |
| 1093 | + | |
| 1094 | + | |
| 1095 | + | |
| 1096 | + | |
| 1097 | + | |
| 1098 | + | |
| 1099 | + | |
| 1100 | + | |
| 1101 | + | |
| 1102 | + | |
| 1103 | + | |
| 1104 | + | |
| 1105 | + | |
| 1106 | + | |
| 1107 | + | |
| 1108 | + | |
| 1109 | + | |
| 1110 | + | |
| 1111 | + | |
| 1112 | + | |
| 1113 | + | |
| 1114 | + | |
| 1115 | + | |
| 1116 | + | |
| 1117 | + | |
1007 | 1118 | | |
1008 | 1119 | | |
1009 | 1120 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
11 | 11 | | |
12 | 12 | | |
13 | 13 | | |
| 14 | + | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
17 | 18 | | |
18 | | - | |
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
| |||
114 | 114 | | |
115 | 115 | | |
116 | 116 | | |
117 | | - | |
118 | | - | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
119 | 137 | | |
120 | 138 | | |
121 | 139 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
313 | 313 | | |
314 | 314 | | |
315 | 315 | | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
316 | 414 | | |
317 | 415 | | |
318 | 416 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
281 | 281 | | |
282 | 282 | | |
283 | 283 | | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
| 359 | + | |
| 360 | + | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
0 commit comments