# SPDX-License-Identifier: LGPL-3.0-or-later
2+ """Common utilities for the pt_expt backend.
3+
4+ This module provides the core infrastructure for automatically wrapping dpmodel
5+ classes (array_api_compat-based) as PyTorch modules. The key insight is to
6+ detect attributes by their **value type** rather than by hard-coded names:
7+
8+ - numpy arrays → torch buffers (persistent state like statistics, masks)
9+ - dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup)
10+ - None values → clear existing buffers
11+
12+ This eliminates the need to manually enumerate attribute names in each wrapper's
13+ __setattr__ method, making the codebase more maintainable when dpmodel adds
14+ new attributes.
15+ """
16+
17+ from collections .abc import (
18+ Callable ,
19+ )
220from typing import (
321 Any ,
422 overload ,
725import numpy as np
826import torch
927
10- from deepmd .pt_expt . utils import (
11- env ,
28+ from deepmd .dpmodel . common import (
29+ NativeOP ,
1230)
1331
# ---------------------------------------------------------------------------
# dpmodel → pt_expt converter registry
# ---------------------------------------------------------------------------
_DPMODEL_TO_PT_EXPT: dict[type[NativeOP], Callable[[NativeOP], torch.nn.Module]] = {}
"""Maps each dpmodel class to a callable that builds its pt_expt wrapper.

Wrapper modules (e.g. exclude_mask.py, network.py) fill this table at import
time via `register_dpmodel_mapping`. Whenever `dpmodel_setattr` sees a dpmodel
object being assigned as an attribute, it looks the object's exact type up
here to find the conversion function.

Typical entries:
    AtomExcludeMaskDP   -> lambda v: AtomExcludeMask(v.ntypes, exclude_types=...)
    NetworkCollectionDP -> lambda v: NetworkCollection.deserialize(v.serialize())
"""
47+
48+
def register_dpmodel_mapping(
    dpmodel_cls: type[NativeOP], converter: Callable[[NativeOP], torch.nn.Module]
) -> None:
    """Record how instances of *dpmodel_cls* become pt_expt torch modules.

    Each pt_expt wrapper module calls this once, at import time, right after
    its wrapper class definition. The registered *converter* is later invoked
    by ``dpmodel_setattr`` whenever an instance of *dpmodel_cls* is assigned
    as an attribute on a wrapper, replacing the dpmodel object with the
    returned ``torch.nn.Module``.

    Parameters
    ----------
    dpmodel_cls : type[NativeOP]
        The dpmodel class used as the lookup key (exact type match, no
        isinstance semantics), e.g. AtomExcludeMaskDP or NetworkCollectionDP.
    converter : Callable[[NativeOP], torch.nn.Module]
        Builds the pt_expt module from a dpmodel instance. Two common shapes:

        - rebuild from constructor args: ``lambda v: PtExptClass(v.ntypes, ...)``
        - serialization round trip:
          ``lambda v: PtExptClass.deserialize(v.serialize())``

    Notes
    -----
    Registration must happen before ``dpmodel_setattr`` can encounter an
    instance of *dpmodel_cls* — in practice, immediately after the wrapper
    class definition at module import time. A later call for the same class
    silently overwrites the earlier converter.

    Examples
    --------
    >>> register_dpmodel_mapping(
    ...     AtomExcludeMaskDP,
    ...     lambda v: AtomExcludeMask(
    ...         v.ntypes, exclude_types=list(v.get_exclude_types())
    ...     ),
    ... )
    """
    _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter
87+
88+
def try_convert_module(value: Any) -> torch.nn.Module | None:
    """Look up and apply the registered pt_expt converter for *value*, if any.

    The exact runtime type of *value* (not its MRO) is used as the key into
    the ``_DPMODEL_TO_PT_EXPT`` registry, so only classes explicitly passed to
    ``register_dpmodel_mapping`` are ever converted — subclasses are not
    matched implicitly. This keeps conversion behavior fully predictable.

    Parameters
    ----------
    value : Any
        Candidate for conversion; typically a dpmodel object such as
        AtomExcludeMaskDP or NetworkCollectionDP, but any object is accepted.

    Returns
    -------
    torch.nn.Module or None
        The pt_expt wrapper produced by the registered converter, or None
        when no converter exists for ``type(value)``.

    Notes
    -----
    Called from ``dpmodel_setattr`` when an assigned attribute might be a
    dpmodel instance; a non-None result is what the caller should store
    instead of the original value.
    """
    converter = _DPMODEL_TO_PT_EXPT.get(type(value))
    return None if converter is None else converter(value)
122+
123+
def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]:
    """Shared ``__setattr__`` logic for pt_expt wrappers around dpmodel classes.

    Dispatches on the *value's* type instead of on hard-coded attribute names,
    so wrappers need no per-attribute bookkeeping. Three cases are recognized:

    1. numpy array → torch buffer: persistent state (e.g. davg/dstd statistics,
       masks) that belongs in ``state_dict`` and follows ``.to(device)``.
    2. ``None`` → clear an existing buffer of the same name.
    3. dpmodel object → pt_expt module: nested dpmodel objects (e.g.
       AtomExcludeMaskDP, NetworkCollectionDP) are swapped for their wrappers
       through the converter registry.

    Parameters
    ----------
    obj : torch.nn.Module
        The wrapper whose attribute is being assigned; the caller is expected
        to ensure it is a ``torch.nn.Module``.
    name : str
        Attribute name being set.
    value : Any
        Assigned value; its type decides the handling.

    Returns
    -------
    handled : bool
        True when the assignment was completed here and the caller must NOT
        call ``super().__setattr__``; False when the caller should forward
        the (possibly converted) value to ``super().__setattr__(name, value)``.
    value : Any
        The value to use — possibly a converted pt_expt module, otherwise
        the input unchanged (scalars, lists, unregistered objects, ...).

    Raises
    ------
    TypeError
        If *value* is a ``NativeOP`` for which no converter was registered.

    Notes
    -----
    - In dpmodel, persistent arrays are always stored as ``np.ndarray``
      (scalars go through ``.item()``, lists through ``.tolist()``), so the
      ``np.ndarray`` check reliably identifies buffer-worthy attributes.
    - ``torch.Tensor`` values fall through to ``torch.nn.Module.__setattr__``,
      which updates existing buffers correctly.
    - The ``"_buffers" in obj.__dict__`` / ``"_modules" in obj.__dict__``
      guards skip buffer/module handling until ``torch.nn.Module.__init__``
      has run on *obj*.
    - ``env`` is imported inside the function body to break the import cycle
      common.py → utils/__init__.py → exclude_mask.py → common.py; Python
      caches the module after the first call, so the cost is negligible.

    The canonical wrapper pattern:

    >>> class MyWrapper(MyDPModel, torch.nn.Module):
    ...     def __setattr__(self, name, value):
    ...         handled, value = dpmodel_setattr(self, name, value)
    ...         if not handled:
    ...             super().__setattr__(name, value)

    Examples
    --------
    >>> obj.davg = np.array([1.0, 2.0])  # numpy array -> torch buffer
    >>> obj.davg = None  # clears the buffer
    >>> obj.emask = AtomExcludeMaskDP(...)  # -> AtomExcludeMask module
    """
    from deepmd.pt_expt.utils import env  # deferred - avoids circular import

    buffers_ready = "_buffers" in obj.__dict__

    # Case 1: numpy array becomes (or replaces) a torch buffer.
    if buffers_ready and isinstance(value, np.ndarray):
        tensor = torch.as_tensor(value, device=env.DEVICE)
        if name not in obj._buffers:
            obj.register_buffer(name, tensor)
        else:
            obj._buffers[name] = tensor
        return True, tensor

    # Case 2: None clears an already-registered buffer.
    if buffers_ready and value is None and name in obj._buffers:
        obj._buffers[name] = None
        return True, None

    # Case 3: dpmodel object is replaced by its registered pt_expt wrapper.
    if "_modules" in obj.__dict__ and not isinstance(value, torch.nn.Module):
        module = try_convert_module(value)
        if module is not None:
            return False, module
        # A NativeOP with no registered converter is a programming error.
        if isinstance(value, NativeOP):
            raise TypeError(
                f"Attempted to assign a dpmodel object of type {type(value).__name__} "
                f"but no converter is registered. Please call register_dpmodel_mapping "
                f"for this type. If this object doesn't need conversion, register it "
                f"with an identity converter: lambda v: v"
            )

    # Anything else is forwarded untouched to the normal __setattr__ path.
    return False, value
234+
235+
236+ # ---------------------------------------------------------------------------
237+ # Utility
238+ # ---------------------------------------------------------------------------
@overload
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...


@overload
def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...


@overload
def to_torch_array(array: None) -> None: ...


def to_torch_array(array: Any) -> torch.Tensor | None:
    """Convert input to a torch tensor on the pt_expt device.

    Accepts numpy arrays, torch tensors (possibly on another device), other
    array-likes accepted by ``torch.as_tensor``, or None, and normalizes them
    onto the pt_expt backend's configured device.

    Parameters
    ----------
    array : Any
        The input to convert:

        - None → returned as-is
        - torch.Tensor → moved to the pt_expt device
        - numpy array or array-like → converted via ``torch.as_tensor``

    Returns
    -------
    torch.Tensor or None
        The input as a torch tensor on ``env.DEVICE``, or None if the input
        was None.

    Notes
    -----
    Uses the same deferred import as ``dpmodel_setattr`` to avoid the
    circular dependency on ``deepmd.pt_expt.utils``; ``env.DEVICE`` selects
    the target device (typically CPU for pt_expt).

    Examples
    --------
    >>> import numpy as np
    >>> tensor = to_torch_array(np.array([1.0, 2.0, 3.0]))
    >>> tensor.device
    device(type='cpu')  # or whatever env.DEVICE is set to
    """
    from deepmd.pt_expt.utils import env  # deferred - avoids circular import

    if array is None:
        return None
    if torch.is_tensor(array):
        return array.to(device=env.DEVICE)
    return torch.as_tensor(array, device=env.DEVICE)
293+
294+
def _ensure_registrations() -> None:
    """Populate the converter registry by importing ``deepmd.pt_expt.utils``.

    Importing the utils package runs its submodules (network, exclude_mask,
    env_mat, ...), each of which calls ``register_dpmodel_mapping`` at import
    time. Running this here — after all functions above exist, to sidestep
    circular imports — guarantees the registry is complete before any
    descriptor or fitting assigns a dpmodel attribute.
    """
    from deepmd.pt_expt import utils  # noqa: F401


_ensure_registrations()