Skip to content

Commit 7cf88d2

Browse files
authored
Merge branch 'master' into refact-fitting-net
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
2 parents ad83d98 + 97d8ded commit 7cf88d2

18 files changed

Lines changed: 85 additions & 20 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ repos:
6363
rev: v21.1.8
6464
hooks:
6565
- id: clang-format
66-
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$)
66+
exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|source/tests/infer/.+\.json$)
6767
# markdown, yaml, CSS, javascript
6868
- repo: https://github.com/pre-commit/mirrors-prettier
6969
rev: v4.0.0-alpha.8

deepmd/pt_expt/common.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@
2525
import numpy as np
2626
import torch
2727

28+
from deepmd.dpmodel.common import (
29+
NativeOP,
30+
)
31+
2832
# ---------------------------------------------------------------------------
2933
# dpmodel → pt_expt converter registry
3034
# ---------------------------------------------------------------------------
31-
_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {}
35+
_DPMODEL_TO_PT_EXPT: dict[type[NativeOP], Callable[[NativeOP], torch.nn.Module]] = {}
3236
"""Registry mapping dpmodel classes to their pt_expt converter functions.
3337
3438
This registry is populated at module import time via `register_dpmodel_mapping`
@@ -43,7 +47,7 @@
4347

4448

4549
def register_dpmodel_mapping(
46-
dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module]
50+
dpmodel_cls: type[NativeOP], converter: Callable[[NativeOP], torch.nn.Module]
4751
) -> None:
4852
"""Register a converter that turns a dpmodel instance into a pt_expt Module.
4953
@@ -54,10 +58,10 @@ def register_dpmodel_mapping(
5458
5559
Parameters
5660
----------
57-
dpmodel_cls : type
61+
dpmodel_cls : type[NativeOP]
5862
The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP).
5963
This is the key used for lookup in dpmodel_setattr.
60-
converter : Callable[[Any], torch.nn.Module]
64+
converter : Callable[[NativeOP], torch.nn.Module]
6165
A callable that converts a dpmodel instance to a pt_expt module.
6266
Common patterns:
6367
- Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...)
@@ -85,7 +89,7 @@ def register_dpmodel_mapping(
8589
def try_convert_module(value: Any) -> torch.nn.Module | None:
8690
"""Convert a dpmodel object to its pt_expt wrapper if a converter is registered.
8791
88-
This function looks up the type of *value* in the _DPMODEL_TO_PT_EXPT
92+
This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT
8993
registry. If a converter is found, it invokes it to produce a torch.nn.Module
9094
wrapper; otherwise it returns None.
9195
@@ -103,8 +107,9 @@ def try_convert_module(value: Any) -> torch.nn.Module | None:
103107
104108
Notes
105109
-----
106-
This function uses exact type matching. Each dpmodel class must be explicitly
107-
registered via register_dpmodel_mapping.
110+
This function uses exact type matching (not isinstance checks) to ensure
111+
predictable behavior. Each dpmodel class must be explicitly registered via
112+
register_dpmodel_mapping.
108113
109114
The function is called by dpmodel_setattr when it encounters an object that
110115
might be a dpmodel instance. If conversion succeeds, the caller should use
@@ -211,9 +216,19 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool,
211216

212217
# dpmodel object → pt_expt module
213218
if "_modules" in obj.__dict__:
214-
converted = try_convert_module(value)
215-
if converted is not None:
216-
return False, converted
219+
# Try to convert dpmodel objects that aren't already torch.nn.Modules
220+
if not isinstance(value, torch.nn.Module):
221+
converted = try_convert_module(value)
222+
if converted is not None:
223+
return False, converted
224+
# If this is a NativeOP that should have been registered but wasn't, raise error
225+
if isinstance(value, NativeOP):
226+
raise TypeError(
227+
f"Attempted to assign a dpmodel object of type {type(value).__name__} "
228+
f"but no converter is registered. Please call register_dpmodel_mapping "
229+
f"for this type. If this object doesn't need conversion, register it "
230+
f"with an identity converter: lambda v: v"
231+
)
217232

218233
return False, value
219234

@@ -275,3 +290,18 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
275290
if torch.is_tensor(array):
276291
return array.to(device=env.DEVICE)
277292
return torch.as_tensor(array, device=env.DEVICE)
293+
294+
295+
# Import utils to trigger dpmodel→pt_expt converter registrations
296+
# This must happen after the functions above are defined to avoid circular imports
297+
def _ensure_registrations() -> None:
298+
"""Import pt_expt.utils modules to register converters.
299+
300+
This function is called on module import to ensure all dpmodel→pt_expt
301+
converters are registered before any descriptors/fittings try to use them.
302+
"""
303+
# Import triggers registration of NetworkCollection, ExcludeMask, EnvMat
304+
from deepmd.pt_expt import utils # noqa: F401
305+
306+
307+
_ensure_registrations()

deepmd/pt_expt/utils/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
from deepmd.dpmodel.utils.env_mat import (
4+
EnvMat,
5+
)
6+
from deepmd.pt_expt.common import (
7+
register_dpmodel_mapping,
8+
)
9+
310
from .exclude_mask import (
411
AtomExcludeMask,
512
PairExcludeMask,
@@ -11,6 +18,10 @@
1118
TypeEmbedNet,
1219
)
1320

21+
# Register EnvMat with identity converter - it doesn't need wrapping
22+
# as it's a stateless utility class
23+
register_dpmodel_mapping(EnvMat, lambda v: v)
24+
1425
__all__ = [
1526
"AtomExcludeMask",
1627
"NetworkCollection",

examples/water_multi_task/pytorch_example/input_torch_with_alias.json

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,31 @@
5858
"type_map": "type_map_all",
5959
"descriptor": "dpa3_descriptor",
6060
"fitting_net": "shared_fit_with_id",
61-
"model_branch_alias": ["Default","Water"],
61+
"model_branch_alias": [
62+
"Default",
63+
"Water"
64+
],
6265
"info": {
63-
"description": "Water model with DPA3 descriptor and shared fitting net",
64-
"observed_type": ["H", "O"]
66+
"description": "Water model with DPA3 descriptor and shared fitting net",
67+
"observed_type": [
68+
"H",
69+
"O"
70+
]
6571
}
6672
},
6773
"water_2": {
6874
"type_map": "type_map_all",
6975
"descriptor": "dpa3_descriptor",
7076
"fitting_net": "shared_fit_with_id",
71-
"model_branch_alias": ["Water2"],
77+
"model_branch_alias": [
78+
"Water2"
79+
],
7280
"info": {
73-
"description": "Water duplicated model with DPA3 descriptor and shared fitting net",
74-
"observed_type": ["H", "O"]
81+
"description": "Water duplicated model with DPA3 descriptor and shared fitting net",
82+
"observed_type": [
83+
"H",
84+
"O"
85+
]
7586
}
7687
}
7788
}

source/tests/consistent/descriptor/test_dpa1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def data(self) -> dict:
127127
"type_map": ["O", "H"] if use_econf_tebd else None,
128128
"seed": 1145141919810,
129129
"trainable": False,
130+
"activation_function": "relu",
130131
}
131132

132133
def is_meaningless_zero_attention_layer_tests(

source/tests/consistent/descriptor/test_dpa2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def data(self) -> dict:
132132
"tebd_dim": 4,
133133
"tebd_input_mode": repinit_tebd_input_mode,
134134
"set_davg_zero": repinit_set_davg_zero,
135-
"activation_function": "tanh",
135+
"activation_function": "relu",
136136
"type_one_side": repinit_type_one_side,
137137
"use_three_body": repinit_use_three_body,
138138
"three_body_sel": 8,
@@ -163,7 +163,7 @@ def data(self) -> dict:
163163
"attn2_hidden": 10,
164164
"attn2_nhead": 2,
165165
"attn2_has_gate": repformer_attn2_has_gate,
166-
"activation_function": "tanh",
166+
"activation_function": "relu",
167167
"update_style": repformer_update_style,
168168
"update_residual": 0.001,
169169
"update_residual_init": repformer_update_residual_init,

source/tests/consistent/descriptor/test_dpa3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def data(self) -> dict:
125125
}
126126
),
127127
# kwargs for descriptor
128-
"activation_function": "silu",
128+
"activation_function": "relu",
129129
"precision": precision,
130130
"exclude_types": exclude_types,
131131
"env_protection": 0.0,

source/tests/consistent/descriptor/test_hybrid.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def data(self) -> dict:
6161
"type_one_side": True,
6262
"precision": "float64",
6363
"seed": 20240229,
64+
"activation_function": "relu",
6465
},
6566
{
6667
"type": "se_e2_a",
@@ -73,6 +74,7 @@ def data(self) -> dict:
7374
"type_one_side": True,
7475
"precision": "float64",
7576
"seed": 20240229,
77+
"activation_function": "relu",
7678
},
7779
]
7880
}

source/tests/consistent/descriptor/test_se_atten_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def data(self) -> dict:
124124
"use_tebd_bias": use_tebd_bias,
125125
"type_map": ["O", "H"] if use_econf_tebd else None,
126126
"seed": 1145141919810,
127+
"activation_function": "relu",
127128
}
128129

129130
def is_meaningless_zero_attention_layer_tests(

source/tests/consistent/descriptor/test_se_e2_a.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def data(self) -> dict:
9999
"env_protection": env_protection,
100100
"precision": precision,
101101
"seed": 1145141919810,
102+
"activation_function": "relu",
102103
}
103104

104105
@property

0 commit comments

Comments
 (0)