Skip to content

Commit d3b01da

Browse files
author
Han Wang
committed
utility to handle dpmodel -> pt_expt conversion
1 parent 8bdb1f8 commit d3b01da

6 files changed

Lines changed: 293 additions & 88 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 249 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Common utilities for the pt_expt backend.
3+
4+
This module provides the core infrastructure for automatically wrapping dpmodel
5+
classes (array_api_compat-based) as PyTorch modules. The key insight is to
6+
detect attributes by their **value type** rather than by hard-coded names:
7+
8+
- numpy arrays → torch buffers (persistent state like statistics, masks)
9+
- dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup)
10+
- None values → clear existing buffers
11+
12+
This eliminates the need to manually enumerate attribute names in each wrapper's
13+
__setattr__ method, making the codebase more maintainable when dpmodel adds
14+
new attributes.
15+
"""
16+
17+
from collections.abc import (
18+
Callable,
19+
)
220
from typing import (
321
Any,
422
overload,
@@ -7,11 +25,203 @@
725
import numpy as np
826
import torch
927

10-
from deepmd.pt_expt.utils import (
11-
env,
12-
)
28+
# ---------------------------------------------------------------------------
29+
# dpmodel → pt_expt converter registry
30+
# ---------------------------------------------------------------------------
31+
_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {}
32+
"""Registry mapping dpmodel classes to their pt_expt converter functions.
33+
34+
This registry is populated at module import time via `register_dpmodel_mapping`
35+
calls in each pt_expt wrapper module (e.g., exclude_mask.py, network.py). When
36+
dpmodel_setattr encounters a dpmodel object, it looks up the object's type in
37+
this registry to find the appropriate converter.
38+
39+
Examples of registered mappings:
40+
- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...)
41+
- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize())
42+
"""
43+
44+
45+
def register_dpmodel_mapping(
    dpmodel_cls: type,
    converter: Callable[[Any], torch.nn.Module],
) -> None:
    """Associate a dpmodel class with the function that builds its pt_expt wrapper.

    Each pt_expt wrapper module calls this once, at import time, right after
    defining its wrapper class. The entry is later consulted whenever an
    instance of ``dpmodel_cls`` is assigned as an attribute on a pt_expt
    wrapper, so the assignment stores a torch.nn.Module instead of the raw
    dpmodel object.

    Parameters
    ----------
    dpmodel_cls : type
        The dpmodel class used as the registry key
        (e.g., AtomExcludeMaskDP, NetworkCollectionDP).
    converter : Callable[[Any], torch.nn.Module]
        Callable that receives a dpmodel instance and returns the matching
        pt_expt module. Two common shapes:

        - rebuild from constructor arguments:
          ``lambda v: PtExptClass(v.ntypes, ...)``
        - round-trip through serialization:
          ``lambda v: PtExptClass.deserialize(v.serialize())``

    Notes
    -----
    Registration has to happen after the wrapper class exists and before any
    instance of ``dpmodel_cls`` is assigned to a wrapper attribute.
    Registering the same class again replaces the earlier converter.

    Examples
    --------
    >>> register_dpmodel_mapping(
    ...     AtomExcludeMaskDP,
    ...     lambda v: AtomExcludeMask(
    ...         v.ntypes, exclude_types=list(v.get_exclude_types())
    ...     ),
    ... )
    """
    _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter
83+
84+
85+
def try_convert_module(value: Any) -> torch.nn.Module | None:
    """Return the pt_expt wrapper for *value*, or None if its type is unregistered.

    The registry is keyed by class, and the lookup uses ``type(value)``
    directly, so only classes that were registered explicitly are converted;
    subclasses of a registered class are not matched.

    Parameters
    ----------
    value : Any
        Candidate object, typically a dpmodel instance such as
        AtomExcludeMaskDP or NetworkCollectionDP.

    Returns
    -------
    torch.nn.Module or None
        Result of the registered converter applied to *value*, or None when
        no converter is registered for ``type(value)``.

    Notes
    -----
    Callers that receive a module should store it in place of the original
    value.
    """
    build_wrapper = _DPMODEL_TO_PT_EXPT.get(type(value))
    return None if build_wrapper is None else build_wrapper(value)
118+
119+
120+
def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]:
    """Common ``__setattr__`` logic for pt_expt wrappers around dpmodel classes.

    Attributes are classified by the *type of the assigned value* rather than
    by hard-coded names, so wrappers do not need to enumerate attribute names:

    1. **numpy arrays → torch buffers**: persistent state (e.g. davg, dstd,
       masks) that should live in ``state_dict`` and follow ``.to(device)``.
    2. **None → clear buffer**: only when ``name`` is already a buffer.
    3. **dpmodel objects → pt_expt modules**: converted through the registry
       consulted by ``try_convert_module``.

    Parameters
    ----------
    obj : torch.nn.Module
        The pt_expt wrapper whose attribute is being set.
    name : str
        The attribute name being set.
    value : Any
        The value being assigned; its type decides how it is handled.

    Returns
    -------
    handled : bool
        True if the attribute has been fully set (the caller must NOT call
        ``super().__setattr__``). False if the caller should forward the
        returned value to ``super().__setattr__(name, value)``.
    value : Any
        The value to use: the created tensor, None, the converted module, or
        the unchanged input.

    Notes
    -----
    - Buffer handling only happens once ``"_buffers"`` is present in
      ``obj.__dict__``; module conversion only once ``"_modules"`` is present.
      Before that, values are passed back unchanged.
    - torch.Tensor values are not intercepted here; they are returned
      unchanged for the caller to forward to ``torch.nn.Module.__setattr__``.
    - An attribute first assigned ``None`` (before any buffer of that name
      exists) is stored by the caller as a plain instance attribute. When it
      is later assigned a numpy array, that plain attribute is removed before
      the buffer is registered; otherwise ``register_buffer`` would raise
      ``KeyError`` because the name already exists on the object.
    - ``deepmd.pt_expt.utils.env`` is imported lazily, inside the only branch
      that needs it, to break the import cycle
      common.py → utils/__init__.py → exclude_mask.py → common.py.

    Typical usage in a wrapper class:

    >>> class MyWrapper(MyDPModel, torch.nn.Module):
    ...     def __setattr__(self, name, value):
    ...         handled, value = dpmodel_setattr(self, name, value)
    ...         if not handled:
    ...             super().__setattr__(name, value)

    Examples
    --------
    >>> obj.davg = np.array([1.0, 2.0])  # becomes a torch.Tensor buffer
    >>> obj.davg = None  # existing buffer is set to None
    >>> obj.emask = AtomExcludeMaskDP(...)  # becomes an AtomExcludeMask module
    """
    has_buffers = "_buffers" in obj.__dict__

    # numpy array → torch buffer
    if has_buffers and isinstance(value, np.ndarray):
        from deepmd.pt_expt.utils import env  # deferred - avoids circular import

        tensor = torch.as_tensor(value, device=env.DEVICE)
        if name in obj._buffers:
            obj._buffers[name] = tensor
        else:
            # A previous non-buffer assignment (typically None) leaves a plain
            # instance attribute behind; register_buffer rejects names that
            # already exist on the object, so drop the stale entry first.
            obj.__dict__.pop(name, None)
            obj.register_buffer(name, tensor)
        return True, tensor

    # clear an existing buffer to None
    if value is None and has_buffers and name in obj._buffers:
        obj._buffers[name] = None
        return True, None

    # dpmodel object → pt_expt module (caller stores it via super().__setattr__)
    if "_modules" in obj.__dict__:
        converted = try_convert_module(value)
        if converted is not None:
            return False, converted

    return False, value
220+
221+
222+
# ---------------------------------------------------------------------------
223+
# Utility
224+
# ---------------------------------------------------------------------------
15225
@overload
16226
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...
17227

@@ -25,7 +235,42 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...
25235

26236

27237
def to_torch_array(array: Any) -> torch.Tensor | None:
28-
"""Convert input to a torch tensor on the pt-expt device."""
238+
"""Convert input to a torch tensor on the pt_expt device.
239+
240+
This utility function handles conversion from various array-like types (numpy
241+
arrays, torch tensors on different devices, etc.) to torch tensors on the
242+
pt_expt backend's configured device.
243+
244+
Parameters
245+
----------
246+
array : Any
247+
The input to convert. Can be:
248+
- None (returns None)
249+
- torch.Tensor (moves to pt_expt device)
250+
- numpy array or array-like (converts to torch.Tensor on pt_expt device)
251+
252+
Returns
253+
-------
254+
torch.Tensor or None
255+
The input as a torch tensor on the pt_expt device (env.DEVICE), or None
256+
if the input was None.
257+
258+
Notes
259+
-----
260+
This function uses the same deferred import pattern as dpmodel_setattr to
261+
avoid circular dependencies. The env module determines the target device
262+
(typically CPU for pt_expt).
263+
264+
Examples
265+
--------
266+
>>> import numpy as np
267+
>>> arr = np.array([1.0, 2.0, 3.0])
268+
>>> tensor = to_torch_array(arr)
269+
>>> tensor.device
270+
device(type='cpu') # or whatever env.DEVICE is set to
271+
"""
272+
from deepmd.pt_expt.utils import env # deferred - avoids circular import
273+
29274
if array is None:
30275
return None
31276
if torch.is_tensor(array):

deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,12 @@
66
import torch
77

88
from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
9+
from deepmd.pt_expt.common import (
10+
dpmodel_setattr,
11+
)
912
from deepmd.pt_expt.descriptor.base_descriptor import (
1013
BaseDescriptor,
1114
)
12-
from deepmd.pt_expt.utils import (
13-
env,
14-
)
15-
from deepmd.pt_expt.utils.exclude_mask import (
16-
PairExcludeMask,
17-
)
18-
from deepmd.pt_expt.utils.network import (
19-
NetworkCollection,
20-
)
2115

2216

2317
@BaseDescriptor.register("se_e2_a_expt")
@@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
3226
return torch.nn.Module.__call__(self, *args, **kwargs)
3327

3428
def __setattr__(self, name: str, value: Any) -> None:
35-
if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
36-
tensor = (
37-
None if value is None else torch.as_tensor(value, device=env.DEVICE)
38-
)
39-
if name in self._buffers:
40-
self._buffers[name] = tensor
41-
return
42-
# Register on first assignment so buffers are in state_dict and moved by .to().
43-
self.register_buffer(name, tensor)
44-
return
45-
if name == "embeddings" and "_modules" in self.__dict__:
46-
if value is not None and not isinstance(value, torch.nn.Module):
47-
if hasattr(value, "serialize"):
48-
value = NetworkCollection.deserialize(value.serialize())
49-
elif isinstance(value, dict):
50-
value = NetworkCollection.deserialize(value)
51-
return super().__setattr__(name, value)
52-
if name == "emask" and "_modules" in self.__dict__:
53-
if value is not None and not isinstance(value, torch.nn.Module):
54-
value = PairExcludeMask(
55-
self.ntypes, exclude_types=list(value.get_exclude_types())
56-
)
57-
return super().__setattr__(name, value)
58-
return super().__setattr__(name, value)
29+
handled, value = dpmodel_setattr(self, name, value)
30+
if not handled:
31+
super().__setattr__(name, value)
5932

6033
def forward(
6134
self,

deepmd/pt_expt/descriptor/se_r.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,12 @@
66
import torch
77

88
from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
9+
from deepmd.pt_expt.common import (
10+
dpmodel_setattr,
11+
)
912
from deepmd.pt_expt.descriptor.base_descriptor import (
1013
BaseDescriptor,
1114
)
12-
from deepmd.pt_expt.utils import (
13-
env,
14-
)
15-
from deepmd.pt_expt.utils.exclude_mask import (
16-
PairExcludeMask,
17-
)
18-
from deepmd.pt_expt.utils.network import (
19-
NetworkCollection,
20-
)
2115

2216

2317
@BaseDescriptor.register("se_e2_r_expt")
@@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
3226
return torch.nn.Module.__call__(self, *args, **kwargs)
3327

3428
def __setattr__(self, name: str, value: Any) -> None:
35-
if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
36-
tensor = (
37-
None if value is None else torch.as_tensor(value, device=env.DEVICE)
38-
)
39-
if name in self._buffers:
40-
self._buffers[name] = tensor
41-
return
42-
# Register on first assignment so buffers are in state_dict and moved by .to().
43-
self.register_buffer(name, tensor)
44-
return
45-
if name == "embeddings" and "_modules" in self.__dict__:
46-
if value is not None and not isinstance(value, torch.nn.Module):
47-
if hasattr(value, "serialize"):
48-
value = NetworkCollection.deserialize(value.serialize())
49-
elif isinstance(value, dict):
50-
value = NetworkCollection.deserialize(value)
51-
return super().__setattr__(name, value)
52-
if name == "emask" and "_modules" in self.__dict__:
53-
if value is not None and not isinstance(value, torch.nn.Module):
54-
value = PairExcludeMask(
55-
self.ntypes, exclude_types=list(value.get_exclude_types())
56-
)
57-
return super().__setattr__(name, value)
58-
return super().__setattr__(name, value)
29+
handled, value = dpmodel_setattr(self, name, value)
30+
if not handled:
31+
super().__setattr__(name, value)
5932

6033
def forward(
6134
self,

deepmd/pt_expt/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
AtomExcludeMask,
55
PairExcludeMask,
66
)
7+
from .network import (
8+
NetworkCollection,
9+
)
710

811
__all__ = [
912
"AtomExcludeMask",
13+
"NetworkCollection",
1014
"PairExcludeMask",
1115
]

0 commit comments

Comments
 (0)