2525import numpy as np
2626import torch
2727
28+ from deepmd .dpmodel .common import (
29+ NativeOP ,
30+ )
31+
2832# ---------------------------------------------------------------------------
2933# dpmodel → pt_expt converter registry
3034# ---------------------------------------------------------------------------
31- _DPMODEL_TO_PT_EXPT : dict [type , Callable [[Any ], torch .nn .Module ]] = {}
35+ _DPMODEL_TO_PT_EXPT : dict [type [ NativeOP ] , Callable [[NativeOP ], torch .nn .Module ]] = {}
3236"""Registry mapping dpmodel classes to their pt_expt converter functions.
3337
3438This registry is populated at module import time via `register_dpmodel_mapping`
4347
4448
4549def register_dpmodel_mapping (
46- dpmodel_cls : type , converter : Callable [[Any ], torch .nn .Module ]
50+ dpmodel_cls : type [ NativeOP ] , converter : Callable [[NativeOP ], torch .nn .Module ]
4751) -> None :
4852 """Register a converter that turns a dpmodel instance into a pt_expt Module.
4953
@@ -54,10 +58,10 @@ def register_dpmodel_mapping(
5458
5559 Parameters
5660 ----------
57- dpmodel_cls : type
61+ dpmodel_cls : type[NativeOP]
5862 The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP).
5963 This is the key used for lookup in dpmodel_setattr.
60- converter : Callable[[Any ], torch.nn.Module]
64+ converter : Callable[[NativeOP ], torch.nn.Module]
6165 A callable that converts a dpmodel instance to a pt_expt module.
6266 Common patterns:
6367 - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...)
@@ -85,7 +89,7 @@ def register_dpmodel_mapping(
8589def try_convert_module (value : Any ) -> torch .nn .Module | None :
8690 """Convert a dpmodel object to its pt_expt wrapper if a converter is registered.
8791
88- This function looks up the type of *value* in the _DPMODEL_TO_PT_EXPT
92+ This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT
8993 registry. If a converter is found, it invokes it to produce a torch.nn.Module
9094 wrapper; otherwise it returns None.
9195
@@ -103,8 +107,9 @@ def try_convert_module(value: Any) -> torch.nn.Module | None:
103107
104108 Notes
105109 -----
106- This function uses exact type matching. Each dpmodel class must be explicitly
107- registered via register_dpmodel_mapping.
110+ This function uses exact type matching (not isinstance checks) to ensure
111+ predictable behavior. Each dpmodel class must be explicitly registered via
112+ register_dpmodel_mapping.
108113
109114 The function is called by dpmodel_setattr when it encounters an object that
110115 might be a dpmodel instance. If conversion succeeds, the caller should use
@@ -211,9 +216,19 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool,
211216
212217 # dpmodel object → pt_expt module
213218 if "_modules" in obj .__dict__ :
214- converted = try_convert_module (value )
215- if converted is not None :
216- return False , converted
219+ # Try to convert dpmodel objects that aren't already torch.nn.Modules
220+ if not isinstance (value , torch .nn .Module ):
221+ converted = try_convert_module (value )
222+ if converted is not None :
223+ return False , converted
224+ # If this is a NativeOP that should have been registered but wasn't, raise error
225+ if isinstance (value , NativeOP ):
226+ raise TypeError (
227+ f"Attempted to assign a dpmodel object of type { type (value ).__name__ } "
228+ f"but no converter is registered. Please call register_dpmodel_mapping "
229+ f"for this type. If this object doesn't need conversion, register it "
230+ f"with an identity converter: lambda v: v"
231+ )
217232
218233 return False , value
219234
@@ -275,3 +290,18 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
275290 if torch .is_tensor (array ):
276291 return array .to (device = env .DEVICE )
277292 return torch .as_tensor (array , device = env .DEVICE )
293+
294+
295+ # Import utils to trigger dpmodel→pt_expt converter registrations
296+ # This must happen after the functions above are defined to avoid circular imports
297+ def _ensure_registrations () -> None :
298+ """Import pt_expt.utils modules to register converters.
299+
300+ This function is called on module import to ensure all dpmodel→pt_expt
301+ converters are registered before any descriptors/fittings try to use them.
302+ """
303+ # Import triggers registration of NetworkCollection, ExcludeMask, EnvMat
304+ from deepmd .pt_expt import utils # noqa: F401
305+
306+
307+ _ensure_registrations ()
0 commit comments