
Commit 97d8ded

wanghan-iapcm and Han Wang authored
refact (pt_expt): provide infrastructure for converting dpmodel classes to PyTorch modules. (#5204)
To be considered after the merge of #5194: automatically wrap dpmodel classes (array_api_compat-based) as PyTorch modules. The key insight is to detect attributes by their **value type** rather than by hard-coded names.

## Summary by CodeRabbit

* **New Features**
  * Registry-driven conversion of dpmodel objects to PyTorch modules, enabling automatic wrapper creation.
  * New PyTorch-friendly descriptor variants with stable forward outputs for se_e2_a and se_r.
  * PyTorch-wrapped exclude-mask utilities and a NetworkCollection of wrapped network types for proper module/state handling.
  * Device-aware tensor conversion and robust handling of numpy buffers and None-valued buffers for reliable serialization and device movement.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 03fccc4 commit 97d8ded

6 files changed

Lines changed: 332 additions & 87 deletions
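Before the per-file diffs, here is a minimal sketch of the usage pattern this commit introduces: each pt_expt wrapper registers a converter at import time and routes every attribute assignment through `dpmodel_setattr`. The sketch is illustrative only: `MyThingDP` and `MyThing` are hypothetical stand-ins, not part of this commit; `register_dpmodel_mapping` and `dpmodel_setattr` are the functions added to `deepmd.pt_expt.common` in the diff below.

```python
# Minimal sketch of the intended usage pattern, assuming the API added in
# deepmd/pt_expt/common.py below.  MyThingDP / MyThing are hypothetical names.
from typing import Any

import torch

from deepmd.pt_expt.common import (
    dpmodel_setattr,
    register_dpmodel_mapping,
)


class MyThingDP:  # stand-in for a dpmodel (NativeOP-style) class
    def __init__(self, ntypes: int) -> None:
        self.ntypes = ntypes


class MyThing(MyThingDP, torch.nn.Module):
    """pt_expt wrapper: dpmodel behaviour plus torch.nn.Module state handling."""

    def __init__(self, ntypes: int) -> None:
        torch.nn.Module.__init__(self)  # set up _buffers/_modules first
        MyThingDP.__init__(self, ntypes)

    def __setattr__(self, name: str, value: Any) -> None:
        # All type-based dispatch (ndarray -> buffer, None -> clear buffer,
        # registered dpmodel object -> pt_expt module) lives in dpmodel_setattr;
        # no attribute names are hard-coded here.
        handled, value = dpmodel_setattr(self, name, value)
        if not handled:
            super().__setattr__(name, value)


# Registered at import time so that a plain MyThingDP assigned as an attribute
# of another wrapper is converted to its torch.nn.Module counterpart.
register_dpmodel_mapping(MyThingDP, lambda v: MyThing(v.ntypes))
```

The same three-line `__setattr__` body is what `se_e2_a.py` switches to in the second diff below.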

File tree

deepmd/pt_expt/common.py

Lines changed: 277 additions & 3 deletions
@@ -1,4 +1,22 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+"""Common utilities for the pt_expt backend.
+
+This module provides the core infrastructure for automatically wrapping dpmodel
+classes (array_api_compat-based) as PyTorch modules. The key insight is to
+detect attributes by their **value type** rather than by hard-coded names:
+
+- numpy arrays → torch buffers (persistent state like statistics, masks)
+- dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup)
+- None values → clear existing buffers
+
+This eliminates the need to manually enumerate attribute names in each wrapper's
+__setattr__ method, making the codebase more maintainable when dpmodel adds
+new attributes.
+"""
+
+from collections.abc import (
+    Callable,
+)
 from typing import (
     Any,
     overload,
@@ -7,11 +25,217 @@
 import numpy as np
 import torch

-from deepmd.pt_expt.utils import (
-    env,
+from deepmd.dpmodel.common import (
+    NativeOP,
 )

+# ---------------------------------------------------------------------------
+# dpmodel → pt_expt converter registry
+# ---------------------------------------------------------------------------
+_DPMODEL_TO_PT_EXPT: dict[type[NativeOP], Callable[[NativeOP], torch.nn.Module]] = {}
+"""Registry mapping dpmodel classes to their pt_expt converter functions.
+
+This registry is populated at module import time via `register_dpmodel_mapping`
+calls in each pt_expt wrapper module (e.g., exclude_mask.py, network.py). When
+dpmodel_setattr encounters a dpmodel object, it looks up the object's type in
+this registry to find the appropriate converter.
+
+Examples of registered mappings:
+- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...)
+- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize())
+"""
+
+
+def register_dpmodel_mapping(
+    dpmodel_cls: type[NativeOP], converter: Callable[[NativeOP], torch.nn.Module]
+) -> None:
+    """Register a converter that turns a dpmodel instance into a pt_expt Module.
+
+    This function is called at module import time by each pt_expt wrapper to
+    register how dpmodel objects should be converted when they're assigned as
+    attributes. The converter is a callable that takes a dpmodel instance and
+    returns the corresponding pt_expt torch.nn.Module wrapper.
+
+    Parameters
+    ----------
+    dpmodel_cls : type[NativeOP]
+        The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP).
+        This is the key used for lookup in dpmodel_setattr.
+    converter : Callable[[NativeOP], torch.nn.Module]
+        A callable that converts a dpmodel instance to a pt_expt module.
+        Common patterns:
+        - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...)
+        - Round-trip via serialization: lambda v: PtExptClass.deserialize(v.serialize())
+
+    Notes
+    -----
+    This function must be called AFTER the pt_expt wrapper class is defined but
+    BEFORE dpmodel_setattr might encounter instances of dpmodel_cls. In practice,
+    this means calling it immediately after the wrapper class definition at module
+    import time.
+
+    Examples
+    --------
+    >>> register_dpmodel_mapping(
+    ...     AtomExcludeMaskDP,
+    ...     lambda v: AtomExcludeMask(
+    ...         v.ntypes, exclude_types=list(v.get_exclude_types())
+    ...     ),
+    ... )
+    """
+    _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter
+
+
+def try_convert_module(value: Any) -> torch.nn.Module | None:
+    """Convert a dpmodel object to its pt_expt wrapper if a converter is registered.
+
+    This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT
+    registry. If a converter is found, it invokes it to produce a torch.nn.Module
+    wrapper; otherwise it returns None.
+
+    Parameters
+    ----------
+    value : Any
+        The value to potentially convert. Typically a dpmodel object like
+        AtomExcludeMaskDP or NetworkCollectionDP.
+
+    Returns
+    -------
+    torch.nn.Module or None
+        The converted pt_expt module if a converter is registered for value's
+        type, otherwise None.
+
+    Notes
+    -----
+    This function uses exact type matching (not isinstance checks) to ensure
+    predictable behavior. Each dpmodel class must be explicitly registered via
+    register_dpmodel_mapping.
+
+    The function is called by dpmodel_setattr when it encounters an object that
+    might be a dpmodel instance. If conversion succeeds, the caller should use
+    the converted module instead of the original value.
+    """
+    converter = _DPMODEL_TO_PT_EXPT.get(type(value))
+    if converter is not None:
+        return converter(value)
+    return None
+
+
+def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]:
+    """Common __setattr__ logic for pt_expt wrappers around dpmodel classes.
+
+    This function implements automatic attribute detection by value type, eliminating
+    the need to hard-code attribute names in each wrapper's __setattr__ method. It
+    handles three cases:
+
+    1. **numpy arrays → torch buffers**: Persistent state like statistics (davg, dstd)
+       or masks that should be saved in state_dict and moved with .to(device).
+    2. **None values → clear buffers**: Setting an existing buffer to None.
+    3. **dpmodel objects → pt_expt modules**: Nested dpmodel objects like
+       AtomExcludeMaskDP or NetworkCollectionDP are converted to their pt_expt
+       wrappers via the registry.
+
+    Parameters
+    ----------
+    obj : torch.nn.Module
+        The pt_expt wrapper object whose attribute is being set. Must be a
+        torch.nn.Module (caller should verify this).
+    name : str
+        The attribute name being set.
+    value : Any
+        The value being assigned. This function inspects the type to determine
+        how to handle it.

+    Returns
+    -------
+    handled : bool
+        True if the attribute has been fully set (caller should NOT call
+        super().__setattr__). False if the caller should forward the (possibly
+        converted) value to super().__setattr__(name, value).
+    value : Any
+        The value to use. May be converted (e.g., dpmodel object → pt_expt module)
+        or unchanged (e.g., scalar, list, or unregistered object).
+
+    Notes
+    -----
+    **Why this design is safe:**
+
+    - In dpmodel, all persistent arrays use `self.xxx = np.array(...)`. Scalars
+      use `.item()`, lists use `.tolist()`. So `isinstance(value, np.ndarray)`
+      reliably identifies buffer-worthy attributes.
+    - torch.Tensor values assigned to existing buffers fall through to
+      torch.nn.Module.__setattr__, which correctly updates them.
+    - dpmodel objects are identified by registry lookup (exact type match), so
+      only explicitly registered types are converted.
+    - The function checks `"_buffers" in obj.__dict__` to ensure the object has
+      been initialized as a torch.nn.Module before attempting buffer operations.
+
+    **Circular import resolution:**
+
+    The function uses a deferred import `from deepmd.pt_expt.utils import env`
+    inside the function body. This breaks the circular dependency chain:
+    common.py → utils/__init__.py → exclude_mask.py → common.py. The import is
+    cached by Python after the first call, so there's no performance penalty.
+
+    **Usage pattern:**
+
+    Typical wrapper classes use this three-line pattern:
+
+    >>> class MyWrapper(MyDPModel, torch.nn.Module):
+    ...     def __setattr__(self, name, value):
+    ...         handled, value = dpmodel_setattr(self, name, value)
+    ...         if not handled:
+    ...             super().__setattr__(name, value)
+
+    Examples
+    --------
+    >>> # Case 1: numpy array → buffer
+    >>> obj.davg = np.array([1.0, 2.0])  # becomes torch.Tensor buffer
+    >>>
+    >>> # Case 2: clear buffer
+    >>> obj.davg = None  # sets buffer to None
+    >>>
+    >>> # Case 3: dpmodel object → pt_expt module
+    >>> obj.emask = AtomExcludeMaskDP(...)  # becomes AtomExcludeMask module
+    """
+    from deepmd.pt_expt.utils import env  # deferred - avoids circular import
+
+    # numpy array → torch buffer
+    if isinstance(value, np.ndarray) and "_buffers" in obj.__dict__:
+        tensor = torch.as_tensor(value, device=env.DEVICE)
+        if name in obj._buffers:
+            obj._buffers[name] = tensor
+            return True, tensor
+        obj.register_buffer(name, tensor)
+        return True, tensor
+
+    # clear an existing buffer to None
+    if value is None and "_buffers" in obj.__dict__ and name in obj._buffers:
+        obj._buffers[name] = None
+        return True, None
+
+    # dpmodel object → pt_expt module
+    if "_modules" in obj.__dict__:
+        # Try to convert dpmodel objects that aren't already torch.nn.Modules
+        if not isinstance(value, torch.nn.Module):
+            converted = try_convert_module(value)
+            if converted is not None:
+                return False, converted
+            # If this is a NativeOP that should have been registered but wasn't, raise error
+            if isinstance(value, NativeOP):
+                raise TypeError(
+                    f"Attempted to assign a dpmodel object of type {type(value).__name__} "
+                    f"but no converter is registered. Please call register_dpmodel_mapping "
+                    f"for this type. If this object doesn't need conversion, register it "
+                    f"with an identity converter: lambda v: v"
+                )
+
+    return False, value
+
+
+# ---------------------------------------------------------------------------
+# Utility
+# ---------------------------------------------------------------------------
 @overload
 def to_torch_array(array: np.ndarray) -> torch.Tensor: ...

@@ -25,9 +249,59 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...


 def to_torch_array(array: Any) -> torch.Tensor | None:
-    """Convert input to a torch tensor on the pt-expt device."""
+    """Convert input to a torch tensor on the pt_expt device.
+
+    This utility function handles conversion from various array-like types (numpy
+    arrays, torch tensors on different devices, etc.) to torch tensors on the
+    pt_expt backend's configured device.
+
+    Parameters
+    ----------
+    array : Any
+        The input to convert. Can be:
+        - None (returns None)
+        - torch.Tensor (moves to pt_expt device)
+        - numpy array or array-like (converts to torch.Tensor on pt_expt device)
+
+    Returns
+    -------
+    torch.Tensor or None
+        The input as a torch tensor on the pt_expt device (env.DEVICE), or None
+        if the input was None.
+
+    Notes
+    -----
+    This function uses the same deferred import pattern as dpmodel_setattr to
+    avoid circular dependencies. The env module determines the target device
+    (typically CPU for pt_expt).
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> arr = np.array([1.0, 2.0, 3.0])
+    >>> tensor = to_torch_array(arr)
+    >>> tensor.device
+    device(type='cpu')  # or whatever env.DEVICE is set to
+    """
+    from deepmd.pt_expt.utils import env  # deferred - avoids circular import
+
     if array is None:
         return None
     if torch.is_tensor(array):
         return array.to(device=env.DEVICE)
     return torch.as_tensor(array, device=env.DEVICE)
+
+
+# Import utils to trigger dpmodel→pt_expt converter registrations
+# This must happen after the functions above are defined to avoid circular imports
+def _ensure_registrations() -> None:
+    """Import pt_expt.utils modules to register converters.
+
+    This function is called on module import to ensure all dpmodel→pt_expt
+    converters are registered before any descriptors/fittings try to use them.
+    """
+    # Import triggers registration of NetworkCollection, ExcludeMask, EnvMat
+    from deepmd.pt_expt import utils  # noqa: F401
+
+
+_ensure_registrations()

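As a quick illustration of the buffer handling described in the `dpmodel_setattr` docstring, the snippet below shows the observable effect of the numpy and None cases. It reuses the hypothetical `MyThing` wrapper sketched earlier and is not a test shipped with this commit.

```python
# Illustrative behaviour of the numpy/None handling in dpmodel_setattr,
# using the hypothetical MyThing wrapper sketched near the top of this page.
import numpy as np
import torch

thing = MyThing(ntypes=2)

# numpy array -> registered torch buffer, so it shows up in state_dict()
# and is cast/moved together with the rest of the module state.
thing.davg = np.zeros((2, 4))
assert isinstance(thing.davg, torch.Tensor)
assert "davg" in thing.state_dict()

thing.to(torch.float32)
assert thing.davg.dtype == torch.float32

# None -> clears the existing buffer.
thing.davg = None
assert thing.davg is None
```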
deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 6 additions & 33 deletions
@@ -6,18 +6,12 @@
 import torch

 from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
+from deepmd.pt_expt.common import (
+    dpmodel_setattr,
+)
 from deepmd.pt_expt.descriptor.base_descriptor import (
     BaseDescriptor,
 )
-from deepmd.pt_expt.utils import (
-    env,
-)
-from deepmd.pt_expt.utils.exclude_mask import (
-    PairExcludeMask,
-)
-from deepmd.pt_expt.utils.network import (
-    NetworkCollection,
-)


 @BaseDescriptor.register("se_e2_a_expt")
@@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
         return torch.nn.Module.__call__(self, *args, **kwargs)

     def __setattr__(self, name: str, value: Any) -> None:
-        if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
-            tensor = (
-                None if value is None else torch.as_tensor(value, device=env.DEVICE)
-            )
-            if name in self._buffers:
-                self._buffers[name] = tensor
-                return
-            # Register on first assignment so buffers are in state_dict and moved by .to().
-            self.register_buffer(name, tensor)
-            return
-        if name == "embeddings" and "_modules" in self.__dict__:
-            if value is not None and not isinstance(value, torch.nn.Module):
-                if hasattr(value, "serialize"):
-                    value = NetworkCollection.deserialize(value.serialize())
-                elif isinstance(value, dict):
-                    value = NetworkCollection.deserialize(value)
-            return super().__setattr__(name, value)
-        if name == "emask" and "_modules" in self.__dict__:
-            if value is not None and not isinstance(value, torch.nn.Module):
-                value = PairExcludeMask(
-                    self.ntypes, exclude_types=list(value.get_exclude_types())
-                )
-            return super().__setattr__(name, value)
-        return super().__setattr__(name, value)
+        handled, value = dpmodel_setattr(self, name, value)
+        if not handled:
+            super().__setattr__(name, value)

     def forward(
         self,

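The converter registrations that the common.py docstring refers to live in the wrapper modules changed by this commit but not shown in this excerpt (`deepmd/pt_expt/utils/exclude_mask.py` and `deepmd/pt_expt/utils/network.py`). A rough reconstruction from the registry docstring is given below; the converter lambdas follow the docstring examples, but the dpmodel-side import paths are assumptions.

```python
# Rough reconstruction of the registrations in the pt_expt wrapper modules
# (not shown in this excerpt).  The converters match the registry docstring;
# the deepmd.dpmodel import paths below are assumed, not taken from the diff.
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP
from deepmd.pt_expt.common import register_dpmodel_mapping
from deepmd.pt_expt.utils.exclude_mask import AtomExcludeMask
from deepmd.pt_expt.utils.network import NetworkCollection

# deepmd/pt_expt/utils/exclude_mask.py: rebuild from constructor arguments.
register_dpmodel_mapping(
    AtomExcludeMaskDP,
    lambda v: AtomExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())),
)

# deepmd/pt_expt/utils/network.py: round-trip through serialize()/deserialize().
register_dpmodel_mapping(
    NetworkCollectionDP,
    lambda v: NetworkCollection.deserialize(v.serialize()),
)
```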