11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ """Common utilities for the pt_expt backend.
3+
4+ This module provides the core infrastructure for automatically wrapping dpmodel
5+ classes (array_api_compat-based) as PyTorch modules. The key insight is to
6+ detect attributes by their **value type** rather than by hard-coded names:
7+
8+ - numpy arrays → torch buffers (persistent state like statistics, masks)
9+ - dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup)
10+ - None values → clear existing buffers
11+
12+ This eliminates the need to manually enumerate attribute names in each wrapper's
13+ __setattr__ method, making the codebase more maintainable when dpmodel adds
14+ new attributes.
15+ """
16+
17+ from collections .abc import (
18+ Callable ,
19+ )
220from typing import (
321 Any ,
422 overload ,
725import numpy as np
826import torch
927
10- from deepmd .pt_expt .utils import (
11- env ,
12- )
# ---------------------------------------------------------------------------
# dpmodel → pt_expt converter registry
# ---------------------------------------------------------------------------
_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {}
"""Converter table keyed by dpmodel class.

Each key is a dpmodel class and each value is a callable that receives an
instance of that class and returns the equivalent pt_expt ``torch.nn.Module``.
Entries are added through `register_dpmodel_mapping`, which the pt_expt
wrapper modules (e.g., exclude_mask.py, network.py) are expected to call at
import time. `try_convert_module` consults this table with the exact type of
the object being converted.

Examples of registered mappings:
- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...)
- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize())
"""
43+
44+
def register_dpmodel_mapping(
    dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module]
) -> None:
    """Associate a dpmodel class with the converter producing its pt_expt Module.

    The pair is stored in the module-level ``_DPMODEL_TO_PT_EXPT`` registry.
    Registering the same class again replaces the previously stored converter.

    Parameters
    ----------
    dpmodel_cls : type
        The dpmodel class used as the registry key (e.g., AtomExcludeMaskDP,
        NetworkCollectionDP). Lookup later happens on the exact type of the
        value being converted.
    converter : Callable[[Any], torch.nn.Module]
        Callable that takes an instance of ``dpmodel_cls`` and returns the
        corresponding pt_expt module. Two common shapes:
        - Rebuild from constructor args: lambda v: PtExptClass(v.ntypes, ...)
        - Serialization round trip: lambda v: PtExptClass.deserialize(v.serialize())

    Notes
    -----
    Call this once the pt_expt wrapper class exists and before any instance
    of ``dpmodel_cls`` can be assigned as an attribute on a wrapper; in
    practice that means right after the wrapper class definition, at module
    import time.

    Examples
    --------
    >>> register_dpmodel_mapping(
    ...     AtomExcludeMaskDP,
    ...     lambda v: AtomExcludeMask(
    ...         v.ntypes, exclude_types=list(v.get_exclude_types())
    ...     ),
    ... )
    """
    _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter
83+
84+
def try_convert_module(value: Any) -> torch.nn.Module | None:
    """Return the pt_expt wrapper for *value*, or None when no converter applies.

    The exact type of *value* is looked up in the ``_DPMODEL_TO_PT_EXPT``
    registry. When an entry exists, its converter is called with *value* and
    the result is returned.

    Parameters
    ----------
    value : Any
        Candidate object, typically a dpmodel instance such as
        AtomExcludeMaskDP or NetworkCollectionDP.

    Returns
    -------
    torch.nn.Module or None
        Whatever the registered converter returns for *value*, or None when
        ``type(value)`` has no registry entry.

    Notes
    -----
    Matching is on ``type(value)`` itself, not via ``isinstance``: a subclass
    of a registered class is NOT converted unless it is registered too.
    """
    factory = _DPMODEL_TO_PT_EXPT.get(type(value))
    return None if factory is None else factory(value)
118+
119+
def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]:
    """Common __setattr__ logic for pt_expt wrappers around dpmodel classes.

    The attribute is classified by the *type of the assigned value* rather
    than by its name, so wrappers do not need to enumerate attribute names.
    Three cases are handled:

    1. **numpy array → torch buffer**: the array is converted with
       ``torch.as_tensor(value, device=env.DEVICE)`` and stored as a buffer
       on ``obj`` (replacing an existing buffer of the same name, or
       registering a new one).
    2. **None → clear buffer**: if ``name`` is already a buffer, the buffer
       slot is set to None.
    3. **dpmodel object → pt_expt module**: the value is passed through
       `try_convert_module`; when a converter is registered for its exact
       type, the converted module is returned for the caller to assign.

    Parameters
    ----------
    obj : torch.nn.Module
        The pt_expt wrapper whose attribute is being set.
    name : str
        The attribute name being set.
    value : Any
        The value being assigned.

    Returns
    -------
    handled : bool
        True if the attribute has been fully set here (the caller must NOT
        call ``super().__setattr__``). False if the caller should forward the
        returned value to ``super().__setattr__(name, value)``.
    value : Any
        The value to use: the created tensor (case 1), None (case 2), the
        converted module (case 3), or the input unchanged.

    Notes
    -----
    - Buffer handling is skipped while ``"_buffers"`` is absent from
      ``obj.__dict__`` and module conversion is skipped while ``"_modules"``
      is absent, i.e. before ``torch.nn.Module.__init__`` has run on ``obj``.
      In that situation the value is returned unchanged with
      ``handled=False``.
    - ``torch.nn.Module.register_buffer`` raises ``KeyError`` when ``name``
      already exists as a non-buffer attribute. A plain instance attribute of
      the same name (for example a ``None`` placeholder that was assigned
      before any buffer existed) is therefore removed before registering, so
      that the assignment replaces it instead of failing. Conflicts with
      parameters, sub-modules or class-level attributes are left to
      ``register_buffer`` to report.
    - ``env`` is imported inside the function body; the inline comment on the
      import marks this as a workaround for a circular import.

    Examples
    --------
    Typical wrapper usage:

    >>> class MyWrapper(MyDPModel, torch.nn.Module):
    ...     def __setattr__(self, name, value):
    ...         handled, value = dpmodel_setattr(self, name, value)
    ...         if not handled:
    ...             super().__setattr__(name, value)

    >>> obj.davg = np.array([1.0, 2.0])  # stored as a torch.Tensor buffer
    >>> obj.davg = None  # existing buffer slot is cleared
    >>> obj.emask = AtomExcludeMaskDP(...)  # converted via the registry
    """
    from deepmd.pt_expt.utils import env  # deferred - avoids circular import

    has_buffers = "_buffers" in obj.__dict__

    # numpy array → torch buffer
    if has_buffers and isinstance(value, np.ndarray):
        tensor = torch.as_tensor(value, device=env.DEVICE)
        if name in obj._buffers:
            obj._buffers[name] = tensor
        else:
            # A plain instance attribute with this name (e.g. an earlier
            # `self.x = None` placeholder) would make register_buffer raise
            # KeyError("attribute ... already exists"); the assignment is
            # meant to replace it, so drop it first.
            obj.__dict__.pop(name, None)
            obj.register_buffer(name, tensor)
        return True, tensor

    # clear an existing buffer to None
    if has_buffers and value is None and name in obj._buffers:
        obj._buffers[name] = None
        return True, None

    # dpmodel object → pt_expt module
    if "_modules" in obj.__dict__:
        converted = try_convert_module(value)
        if converted is not None:
            # Not "handled": torch.nn.Module.__setattr__ must do the actual
            # sub-module registration with the converted value.
            return False, converted

    return False, value
220+
221+
222+ # ---------------------------------------------------------------------------
223+ # Utility
224+ # ---------------------------------------------------------------------------
15225@overload
16226def to_torch_array (array : np .ndarray ) -> torch .Tensor : ...
17227
@@ -25,7 +235,42 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...
25235
26236
27237def to_torch_array (array : Any ) -> torch .Tensor | None :
28- """Convert input to a torch tensor on the pt-expt device."""
238+ """Convert input to a torch tensor on the pt_expt device.
239+
240+ This utility function handles conversion from various array-like types (numpy
241+ arrays, torch tensors on different devices, etc.) to torch tensors on the
242+ pt_expt backend's configured device.
243+
244+ Parameters
245+ ----------
246+ array : Any
247+ The input to convert. Can be:
248+ - None (returns None)
249+ - torch.Tensor (moves to pt_expt device)
250+ - numpy array or array-like (converts to torch.Tensor on pt_expt device)
251+
252+ Returns
253+ -------
254+ torch.Tensor or None
255+ The input as a torch tensor on the pt_expt device (env.DEVICE), or None
256+ if the input was None.
257+
258+ Notes
259+ -----
260+ This function uses the same deferred import pattern as dpmodel_setattr to
261+ avoid circular dependencies. The env module determines the target device
262+ (typically CPU for pt_expt).
263+
264+ Examples
265+ --------
266+ >>> import numpy as np
267+ >>> arr = np.array([1.0, 2.0, 3.0])
268+ >>> tensor = to_torch_array(arr)
269+ >>> tensor.device
270+ device(type='cpu') # or whatever env.DEVICE is set to
271+ """
272+ from deepmd .pt_expt .utils import env # deferred - avoids circular import
273+
29274 if array is None :
30275 return None
31276 if torch .is_tensor (array ):
0 commit comments