Commit 2d7fdc5

Authored by: Han Wang (wanghan-iapcm), njzjz, pre-commit-ci[bot]
refact(dpmodel,pt_expt): fitting net (#5207)
# FittingNet Refactoring: Factory Function to Concrete Class

## Summary

This refactoring converts `FittingNet` from a factory-generated dynamic class to a concrete class in the dpmodel backend, following the same pattern as the EmbeddingNet refactoring. This enables the auto-detection registry mechanism in pt_expt to work seamlessly with FittingNet.

This PR should be considered after #5194 and #5204.

## Motivation

**Before**: `FittingNet` was created by the factory function `make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)`, producing a dynamically-typed class. This caused:

1. **Cannot be registered**: dynamic classes cannot be imported or registered at module import time in the pt_expt registry.
2. **Type matching fails**: each call to `make_fitting_network` creates a new class type, so registry lookup by type fails.

**After**: `FittingNet` is now a concrete class that can be registered in the pt_expt auto-conversion registry.

## Changes

### 1. dpmodel: Concrete `FittingNet` class

**File**: `deepmd/dpmodel/utils/network.py`

- Created concrete `FittingNet(EmbeddingNet)` class
- Moved constructor logic from the factory into `__init__`
- Fixed `deserialize` to use `type(obj.layers[0])` instead of hardcoding `T_Network.__init__(obj, layers)`, allowing the pt_expt subclass to preserve its converted torch layers
- Kept the `make_fitting_network` factory for backward compatibility (for the pt/pd backends)

```python
class FittingNet(EmbeddingNet):
    """The fitting network."""

    def __init__(
        self,
        in_dim,
        out_dim,
        neuron=[24, 48, 96],
        activation_function="tanh",
        resnet_dt=False,
        precision=DEFAULT_PRECISION,
        bias_out=True,
        seed=None,
        trainable=True,
    ):
        # Handle trainable parameter
        if trainable is None:
            trainable = [True] * (len(neuron) + 1)
        elif isinstance(trainable, bool):
            trainable = [trainable] * (len(neuron) + 1)
        # Initialize embedding layers via parent
        super().__init__(
            in_dim,
            neuron=neuron,
            activation_function=activation_function,
            resnet_dt=resnet_dt,
            precision=precision,
            seed=seed,
            trainable=trainable[:-1],
        )
        # Add output layer
        i_in = neuron[-1] if len(neuron) > 0 else in_dim
        self.layers.append(
            NativeLayer(
                i_in,
                out_dim,
                bias=bias_out,
                use_timestep=False,
                activation_function=None,
                resnet=False,
                precision=precision,
                seed=child_seed(seed, len(neuron)),
                trainable=trainable[-1],
            )
        )
        self.out_dim = out_dim
        self.bias_out = bias_out

    @classmethod
    def deserialize(cls, data):
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("@class", None)
        layers = data.pop("layers")
        obj = cls(**data)
        # Use type(obj.layers[0]) to respect subclass layer types
        layer_type = type(obj.layers[0])
        obj.layers = type(obj.layers)(
            [layer_type.deserialize(layer) for layer in layers]
        )
        return obj
```

### 2. pt_expt: Wrapper and registration

**File**: `deepmd/pt_expt/utils/network.py`

- Added import: `from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP`
- Created `FittingNet(FittingNetDP, torch.nn.Module)` wrapper
- Converts dpmodel layers to pt_expt `NativeLayer` (torch modules) in `__init__`
- Registered in the auto-conversion registry

```python
from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP


class FittingNet(FittingNetDP, torch.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        torch.nn.Module.__init__(self)
        FittingNetDP.__init__(self, *args, **kwargs)
        # Convert dpmodel layers to pt_expt NativeLayer
        self.layers = torch.nn.ModuleList(
            [NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.call(x)


register_dpmodel_mapping(
    FittingNetDP,
    lambda v: FittingNet.deserialize(v.serialize()),
)
```

## Tests

### dpmodel tests

**File**: `source/tests/common/dpmodel/test_network.py`

Added to the `TestFittingNet` class:

1. **`test_fitting_net`**: original roundtrip serialization test (already existed)
2. **`test_is_concrete_class`**: verifies `FittingNet` is now a concrete class, not factory output
3. **`test_forward_pass`**: tests that the dpmodel forward pass produces correct output shapes (single and batch)
4. **`test_trainable_parameter_variants`**: tests different trainable configurations (all trainable, all frozen, mixed)

### pt_expt integration tests

**File**: `source/tests/pt_expt/utils/test_network.py`

Created a `TestFittingNetRefactor` test suite with 4 tests:

1. **`test_pt_expt_fitting_net_wraps_dpmodel`**: verifies the pt_expt wrapper inherits correctly and converts layers
2. **`test_pt_expt_fitting_net_forward`**: tests that the pt_expt forward pass returns a `torch.Tensor` with the correct shape
3. **`test_serialization_round_trip_pt_expt`**: tests the pt_expt serialize/deserialize round trip
4. **`test_registry_converts_dpmodel_to_pt_expt`**: tests that `try_convert_module` auto-converts dpmodel to pt_expt

## Verification

All tests pass:

```bash
# dpmodel network tests (includes new FittingNet tests)
python -m pytest source/tests/common/dpmodel/test_network.py -v
# 19 passed in 0.56s (was 16, added 3 FittingNet tests)

# dpmodel FittingNet tests specifically
python -m pytest source/tests/common/dpmodel/test_network.py::TestFittingNet -v
# 4 passed in 0.44s

# pt_expt network tests (EmbeddingNet + FittingNet)
python -m pytest source/tests/pt_expt/utils/test_network.py -v
# 14 passed in 0.45s

# Descriptor tests (verify the refactoring doesn't break existing code)
python -m pytest source/tests/pt_expt/descriptor/ -v
# 8 passed in 5.43s
```

## Benefits

1. **Type-based auto-detection**: `FittingNet` now works with the registry mechanism
2. **Consistency**: same pattern as `EmbeddingNet` and other dpmodel classes
3. **Maintainability**: single source of truth for `FittingNet` in dpmodel
4. **Future-proof**: any dpmodel `FittingNet` instance can be auto-converted to pt_expt

## Backward Compatibility

- Serialization format unchanged (version 1)
- All existing tests pass
- `make_fitting_network` factory kept for the pt/pd backends
- No changes to the public API

## Files Changed

### Modified

- `deepmd/dpmodel/utils/network.py`: concrete `FittingNet` class + `deserialize` fix
- `deepmd/pt_expt/utils/network.py`: `FittingNet` wrapper + registration
- `source/tests/common/dpmodel/test_network.py`: added dpmodel `FittingNet` tests (3 new tests)
- `source/tests/pt_expt/utils/test_network.py`: added pt_expt integration tests (4 new tests)

### Pattern

This refactoring follows the same pattern as `EMBEDDING_NET_REFACTOR.md`:

1. Convert the factory-generated class to a concrete class in dpmodel
2. Fix `deserialize` to use `type(obj.layers[0])`
3. Create a pt_expt wrapper with layer conversion in `__init__`
4. Register with `register_dpmodel_mapping`
5. Add comprehensive tests

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Added PyTorch experimental descriptor implementations for SeT and SeTTebd with full export/tracing support
  * Introduced PyTorch-compatible wrapper classes for network components, enabling seamless integration with PyTorch workflows
* **Improvements**
  * Enhanced device-aware tensor operations across all descriptors for better multi-device support
  * Improved error handling with explicit error messages when statistics are missing, instead of silent failures
  * Refactored `FittingNet` as a concrete class with an explicit public interface
* **Tests**
  * Added comprehensive test coverage for the new PyTorch experimental descriptors and network wrappers
  * Added unit tests validating serialization, deserialization, and forward-pass behavior

---

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
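The type-matching failure that motivates this PR can be reproduced in a few lines: every call to a class factory returns a distinct type object, so a registry keyed by type never matches later factory output. A minimal sketch with toy names (not the deepmd API):

```python
def make_net_class():
    # Factory returns a brand-new class object on every call,
    # mirroring how make_fitting_network produced FittingNet.
    class Net:
        pass

    return Net


A = make_net_class()
B = make_net_class()
# Two factory calls yield two distinct types,
# even though the class bodies are identical.
assert A is not B

registry = {A: "converter-for-A"}
# Lookup with a type produced by a later factory call fails.
assert registry.get(B) is None


# A concrete, module-level class has a single identity,
# so type-based lookup is reliable.
class ConcreteNet:
    pass


registry[ConcreteNet] = "converter"
assert registry.get(type(ConcreteNet())) == "converter"
```

This is why the refactor replaces the factory-generated class with a concrete one before wiring it into the pt_expt registry.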
Parent: 02bd1fc · Commit: 2d7fdc5

4 files changed

Lines changed: 352 additions & 4 deletions

deepmd/dpmodel/utils/network.py

Lines changed: 112 additions & 1 deletion
```diff
@@ -1003,7 +1003,118 @@ def deserialize(cls, data: dict) -> "FittingNet":
     return FN
 
 
-FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)
+class FittingNet(EmbeddingNet):
+    """The fitting network. It may be implemented as an embedding
+    net connected with a linear output layer.
+
+    Parameters
+    ----------
+    in_dim
+        Input dimension.
+    out_dim
+        Output dimension.
+    neuron
+        The number of neurons in each hidden layer.
+    activation_function
+        The activation function.
+    resnet_dt
+        Use time step at the resnet architecture.
+    precision
+        Floating point precision for the model parameters.
+    bias_out
+        The last linear layer has bias.
+    seed : int, optional
+        Random seed.
+    trainable : bool or list[bool], optional
+        Whether the network is trainable.
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        neuron: list[int] = [24, 48, 96],
+        activation_function: str = "tanh",
+        resnet_dt: bool = False,
+        precision: str = DEFAULT_PRECISION,
+        bias_out: bool = True,
+        seed: int | list[int] | None = None,
+        trainable: bool | list[bool] = True,
+    ) -> None:
+        if trainable is None:
+            trainable = [True] * (len(neuron) + 1)
+        elif isinstance(trainable, bool):
+            trainable = [trainable] * (len(neuron) + 1)
+        else:
+            pass
+        super().__init__(
+            in_dim,
+            neuron=neuron,
+            activation_function=activation_function,
+            resnet_dt=resnet_dt,
+            precision=precision,
+            seed=seed,
+            trainable=trainable[:-1],
+        )
+        i_in = neuron[-1] if len(neuron) > 0 else in_dim
+        i_ot = out_dim
+        self.layers.append(
+            NativeLayer(
+                i_in,
+                i_ot,
+                bias=bias_out,
+                use_timestep=False,
+                activation_function=None,
+                resnet=False,
+                precision=precision,
+                seed=child_seed(seed, len(neuron)),
+                trainable=trainable[-1],
+            )
+        )
+        self.out_dim = out_dim
+        self.bias_out = bias_out
+
+    def serialize(self) -> dict:
+        """Serialize the network to a dict.
+
+        Returns
+        -------
+        dict
+            The serialized network.
+        """
+        return {
+            "@class": "FittingNetwork",
+            "@version": 1,
+            "in_dim": self.in_dim,
+            "out_dim": self.out_dim,
+            "neuron": self.neuron.copy(),
+            "activation_function": self.activation_function,
+            "resnet_dt": self.resnet_dt,
+            "precision": self.precision,
+            "bias_out": self.bias_out,
+            "layers": [layer.serialize() for layer in self.layers],
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "FittingNet":
+        """Deserialize the network from a dict.
+
+        Parameters
+        ----------
+        data : dict
+            The dict to deserialize from.
+        """
+        data = data.copy()
+        check_version_compatibility(data.pop("@version", 1), 1, 1)
+        data.pop("@class", None)
+        layers = data.pop("layers")
+        obj = cls(**data)
+        # Use type(obj.layers[0]) to respect subclass layer types
+        layer_type = type(obj.layers[0])
+        obj.layers = type(obj.layers)(
+            [layer_type.deserialize(layer) for layer in layers]
+        )
+        return obj
 
 
 class NetworkCollection:
```
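The `deserialize` change in this hunk is what makes subclass reuse possible: because the layer class is taken from `type(obj.layers[0])` rather than hardcoded, a subclass whose `__init__` builds different layer objects gets those same layer types back after a round trip. A toy illustration of the pattern, using simplified stand-in classes rather than the deepmd ones:

```python
class Layer:
    def __init__(self, w):
        self.w = w

    def serialize(self):
        return {"w": self.w}

    @classmethod
    def deserialize(cls, data):
        return cls(**data)


class Net:
    # Subclasses may override which Layer type __init__ builds.
    layer_cls = Layer

    def __init__(self, n):
        self.layers = [self.layer_cls(w=i) for i in range(n)]

    def serialize(self):
        return {
            "n": len(self.layers),
            "layers": [ly.serialize() for ly in self.layers],
        }

    @classmethod
    def deserialize(cls, data):
        data = data.copy()
        layers = data.pop("layers")
        obj = cls(**data)
        # Respect the layer type the (sub)class constructed instead of
        # hardcoding Layer -- mirroring type(obj.layers[0]) in the diff.
        layer_type = type(obj.layers[0])
        obj.layers = [layer_type.deserialize(d) for d in layers]
        return obj


class TorchishLayer(Layer):
    pass  # stands in for the torch-backed layer of the pt_expt subclass


class TorchishNet(Net):
    layer_cls = TorchishLayer


restored = TorchishNet.deserialize(TorchishNet(3).serialize())
# The subclass round-trips with its own layer type preserved.
assert all(isinstance(ly, TorchishLayer) for ly in restored.layers)
```

With a hardcoded layer class, the last assertion would fail: the subclass would come back with plain `Layer` objects and lose its converted layers.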

deepmd/pt_expt/utils/network.py

Lines changed: 21 additions & 3 deletions
```diff
@@ -11,11 +11,11 @@
     NativeOP,
 )
 from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP
+from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP
 from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP
 from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
 from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
 from deepmd.dpmodel.utils.network import (
-    make_fitting_network,
     make_multilayer_network,
 )
 from deepmd.pt_expt.common import (
@@ -114,8 +114,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     )
 
 
-class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
-    pass
+class FittingNet(FittingNetDP, torch.nn.Module):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        torch.nn.Module.__init__(self)
+        FittingNetDP.__init__(self, *args, **kwargs)
+        # Convert dpmodel layers to pt_expt NativeLayer
+        self.layers = torch.nn.ModuleList(
+            [NativeLayer.deserialize(layer.serialize()) for layer in self.layers]
+        )
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return torch.nn.Module.__call__(self, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.call(x)
+
+
+register_dpmodel_mapping(
+    FittingNetDP,
+    lambda v: FittingNet.deserialize(v.serialize()),
+)
 
 
 class NetworkCollection(NetworkCollectionDP, torch.nn.Module):
```
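The registration line `register_dpmodel_mapping(FittingNetDP, lambda v: FittingNet.deserialize(v.serialize()))` relies only on the shared serialization format: the converter never touches internals, it just round-trips the dict through the target class. A stripped-down sketch of such a registry, with hypothetical helper names and no torch dependency:

```python
_registry = {}


def register_mapping(src_type, converter):
    # Map an exact source type to a conversion callable.
    _registry[src_type] = converter


def try_convert(obj):
    # Exact-type lookup; returns None when nothing is registered,
    # mirroring the role of try_convert_module.
    conv = _registry.get(type(obj))
    return conv(obj) if conv is not None else None


class PlainNet:
    def __init__(self, dim):
        self.dim = dim

    def serialize(self):
        return {"dim": self.dim}

    @classmethod
    def deserialize(cls, data):
        return cls(**data)


class WrappedNet(PlainNet):
    pass  # stands in for the torch-backed wrapper class


# Conversion = serialize on the source type, deserialize on the target.
register_mapping(PlainNet, lambda v: WrappedNet.deserialize(v.serialize()))

converted = try_convert(PlainNet(4))
assert isinstance(converted, WrappedNet)
assert converted.dim == 4
# Unregistered types are simply passed over.
assert try_convert(object()) is None
```

Because the converter only sees `serialize()` output, any change that keeps the serialization format stable (as this PR does, at version 1) keeps the registry mapping working unchanged.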

source/tests/common/dpmodel/test_network.py

Lines changed: 98 additions & 0 deletions
```diff
@@ -313,6 +313,104 @@ def test_fitting_net(self) -> None:
             en1.call(inp)
         np.testing.assert_allclose(en0.call(inp), en1.call(inp))
 
+    def test_is_concrete_class(self) -> None:
+        """Verify FittingNet is a concrete class, not factory-generated."""
+        in_dim = 4
+        out_dim = 1
+        neuron = [8, 16]
+        net = FittingNet(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            neuron=neuron,
+            activation_function="tanh",
+            resnet_dt=True,
+            precision="float64",
+            bias_out=True,
+        )
+        # Check it's the actual FittingNet class, not a dynamic class
+        self.assertEqual(net.__class__.__name__, "FittingNet")
+        self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network")
+        # Verify it has the expected attributes
+        self.assertEqual(net.in_dim, in_dim)
+        self.assertEqual(net.out_dim, out_dim)
+        self.assertEqual(net.neuron, neuron)
+        self.assertEqual(net.activation_function, "tanh")
+        self.assertEqual(net.resnet_dt, True)
+        self.assertEqual(net.bias_out, True)
+        # FittingNet has len(neuron) embedding layers + 1 output layer
+        self.assertEqual(len(net.layers), len(neuron) + 1)
+
+    def test_forward_pass(self) -> None:
+        """Test FittingNet forward pass produces correct output shape."""
+        in_dim = 4
+        out_dim = 3
+        neuron = [8, 16, 32]
+        net = FittingNet(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            neuron=neuron,
+            activation_function="tanh",
+            resnet_dt=True,
+            precision="float64",
+        )
+        # Single sample
+        rng = np.random.default_rng()
+        x = rng.standard_normal(in_dim)
+        out = net.call(x)
+        self.assertEqual(out.shape, (out_dim,))
+
+        # Batch of samples
+        batch_size = 5
+        x_batch = rng.standard_normal((batch_size, in_dim))
+        out_batch = net.call(x_batch)
+        self.assertEqual(out_batch.shape, (batch_size, out_dim))
+
+    def test_trainable_parameter_variants(self) -> None:
+        """Test FittingNet with different trainable configurations."""
+        in_dim = 4
+        out_dim = 2
+        neuron = [8, 16]
+
+        # Test 1: All layers trainable (default)
+        net_all_trainable = FittingNet(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            neuron=neuron,
+            trainable=True,
+        )
+        for layer in net_all_trainable.layers:
+            self.assertTrue(layer.trainable)
+
+        # Test 2: All layers frozen
+        net_all_frozen = FittingNet(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            neuron=neuron,
+            trainable=False,
+        )
+        for layer in net_all_frozen.layers:
+            self.assertFalse(layer.trainable)
+
+        # Test 3: Mixed trainable (embedding layers frozen, output layer trainable)
+        trainable_list = [False, False, True]  # 2 embedding layers + 1 output layer
+        net_mixed = FittingNet(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            neuron=neuron,
+            trainable=trainable_list,
+        )
+        self.assertFalse(net_mixed.layers[0].trainable)  # First embedding layer
+        self.assertFalse(net_mixed.layers[1].trainable)  # Second embedding layer
+        self.assertTrue(net_mixed.layers[2].trainable)  # Output layer
+
+        # Test 4: Serialize/deserialize preserves trainable
+        serialized = net_mixed.serialize()
+        net_restored = FittingNet.deserialize(serialized)
+        for orig_layer, restored_layer in zip(
+            net_mixed.layers, net_restored.layers, strict=True
+        ):
+            self.assertEqual(orig_layer.trainable, restored_layer.trainable)
+
 
 class TestNetworkCollection(unittest.TestCase):
     def setUp(self) -> None:
```
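The trainable handling exercised by `test_trainable_parameter_variants` is simple list arithmetic: a single bool fans out to `len(neuron) + 1` flags, the first `len(neuron)` of which go to the hidden (embedding) layers and the last to the output layer. A sketch of that normalization, extracted as a hypothetical helper:

```python
def normalize_trainable(trainable, neuron):
    # A bool (or None) applies uniformly to every layer:
    # len(neuron) hidden layers plus 1 output layer.
    if trainable is None:
        trainable = [True] * (len(neuron) + 1)
    elif isinstance(trainable, bool):
        trainable = [trainable] * (len(neuron) + 1)
    # Split into hidden-layer flags and the output-layer flag,
    # matching trainable[:-1] / trainable[-1] in FittingNet.__init__.
    return trainable[:-1], trainable[-1]


# A single bool fans out to all layers.
assert normalize_trainable(False, [8, 16]) == ([False, False], False)
# An explicit list is split as-is: frozen hidden layers, trainable output.
assert normalize_trainable([False, False, True], [8, 16]) == ([False, False], True)
```

This is why the mixed-configuration test passes a 3-element list for a 2-layer `neuron` spec: the extra trailing entry is the output layer's flag.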

source/tests/pt_expt/utils/test_network.py

Lines changed: 121 additions & 0 deletions
```diff
@@ -281,3 +281,124 @@ def test_trainable_parameter_handling(self) -> None:
         for layer in net_frozen.layers:
             if layer.w is not None:
                 self.assertFalse(layer.w.requires_grad)
+
+
+class TestFittingNetRefactor(unittest.TestCase):
+    """Tests for the refactored FittingNet pt_expt wrapper."""
+
+    def setUp(self) -> None:
+        self.in_dim = 4
+        self.out_dim = 1
+        self.neuron = [8, 16]
+        self.activation = "tanh"
+        self.resnet_dt = True
+        self.precision = "float64"
+
+    def test_pt_expt_fitting_net_wraps_dpmodel(self) -> None:
+        """Verify pt_expt FittingNet correctly wraps dpmodel."""
+        from deepmd.pt_expt.utils.network import (
+            FittingNet,
+        )
+
+        net = FittingNet(
+            in_dim=self.in_dim,
+            out_dim=self.out_dim,
+            neuron=self.neuron,
+            activation_function=self.activation,
+            resnet_dt=self.resnet_dt,
+            precision=self.precision,
+            seed=GLOBAL_SEED,
+        )
+        # Check it's a torch.nn.Module
+        self.assertIsInstance(net, torch.nn.Module)
+        # Check layers are converted to pt_expt NativeLayer (torch modules)
+        self.assertIsInstance(net.layers, torch.nn.ModuleList)
+        for layer in net.layers:
+            self.assertIsInstance(layer, torch.nn.Module)
+
+    def test_pt_expt_fitting_net_forward(self) -> None:
+        """Test pt_expt FittingNet forward pass returns torch.Tensor."""
+        from deepmd.pt_expt.utils.network import (
+            FittingNet,
+        )
+
+        net = FittingNet(
+            in_dim=self.in_dim,
+            out_dim=self.out_dim,
+            neuron=self.neuron,
+            activation_function=self.activation,
+            resnet_dt=self.resnet_dt,
+            precision=self.precision,
+            seed=GLOBAL_SEED,
+        )
+        x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE)
+        out = net(x)
+        self.assertIsInstance(out, torch.Tensor)
+        self.assertEqual(out.shape, (5, self.out_dim))
+        self.assertEqual(out.dtype, torch.float64)
+
+    def test_serialization_round_trip_pt_expt(self) -> None:
+        """Test pt_expt FittingNet serialization/deserialization."""
+        from deepmd.pt_expt.utils.network import (
+            FittingNet,
+        )
+
+        net = FittingNet(
+            in_dim=self.in_dim,
+            out_dim=self.out_dim,
+            neuron=self.neuron,
+            activation_function=self.activation,
+            resnet_dt=self.resnet_dt,
+            precision=self.precision,
+            seed=GLOBAL_SEED,
+        )
+        x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE)
+        out1 = net(x)
+
+        # Serialize and deserialize
+        serialized = net.serialize()
+        net2 = FittingNet.deserialize(serialized)
+
+        # Verify layers are still pt_expt NativeLayer modules
+        self.assertIsInstance(net2.layers, torch.nn.ModuleList)
+        for layer in net2.layers:
+            self.assertIsInstance(layer, torch.nn.Module)
+
+        out2 = net2(x)
+        np.testing.assert_allclose(
+            out1.detach().cpu().numpy(),
+            out2.detach().cpu().numpy(),
+        )
+
+    def test_registry_converts_dpmodel_to_pt_expt(self) -> None:
+        """Test that dpmodel FittingNet can be converted to pt_expt via registry."""
+        from deepmd.dpmodel.utils.network import FittingNet as DPFittingNet
+        from deepmd.pt_expt.common import (
+            try_convert_module,
+        )
+        from deepmd.pt_expt.utils.network import (
+            FittingNet,
+        )
+
+        # Create dpmodel FittingNet
+        dp_net = DPFittingNet(
+            in_dim=self.in_dim,
+            out_dim=self.out_dim,
+            neuron=self.neuron,
+            activation_function=self.activation,
+            resnet_dt=self.resnet_dt,
+            precision=self.precision,
+            seed=GLOBAL_SEED,
+        )
+
+        # Try to convert via registry
+        converted = try_convert_module(dp_net)
+
+        # Should return pt_expt FittingNet
+        self.assertIsNotNone(converted)
+        self.assertIsInstance(converted, torch.nn.Module)
+        self.assertIsInstance(converted, FittingNet)
+
+        # Verify layers are pt_expt modules
+        for layer in converted.layers:
+            self.assertIsInstance(layer, torch.nn.Module)
```