2525import numpy as np
2626import torch
2727
28+ from deepmd .dpmodel .common import (
29+ NativeOP ,
30+ )
31+
2832# ---------------------------------------------------------------------------
2933# dpmodel → pt_expt converter registry
3034# ---------------------------------------------------------------------------
31- _DPMODEL_TO_PT_EXPT : dict [type , Callable [[Any ], torch .nn .Module ]] = {}
35+ _DPMODEL_TO_PT_EXPT : dict [type [ NativeOP ] , Callable [[NativeOP ], torch .nn .Module ]] = {}
3236"""Registry mapping dpmodel classes to their pt_expt converter functions.
3337
3438This registry is populated at module import time via `register_dpmodel_mapping`
4347
4448
4549def register_dpmodel_mapping (
46- dpmodel_cls : type , converter : Callable [[Any ], torch .nn .Module ]
50+ dpmodel_cls : type [ NativeOP ] , converter : Callable [[NativeOP ], torch .nn .Module ]
4751) -> None :
4852 """Register a converter that turns a dpmodel instance into a pt_expt Module.
4953
@@ -54,10 +58,10 @@ def register_dpmodel_mapping(
5458
5559 Parameters
5660 ----------
57- dpmodel_cls : type
61+ dpmodel_cls : type[NativeOP]
5862 The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP).
5963 This is the key used for lookup in dpmodel_setattr.
60- converter : Callable[[Any ], torch.nn.Module]
64+ converter : Callable[[NativeOP ], torch.nn.Module]
6165 A callable that converts a dpmodel instance to a pt_expt module.
6266 Common patterns:
6367 - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...)
@@ -212,9 +216,17 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool,
212216
213217 # dpmodel object → pt_expt module
214218 if "_modules" in obj .__dict__ :
215- converted = try_convert_module (value )
216- if converted is not None :
217- return False , converted
219+ # Check if this is a NativeOP that needs conversion
220+ if isinstance (value , NativeOP ) and not isinstance (value , torch .nn .Module ):
221+ converted = try_convert_module (value )
222+ if converted is not None :
223+ return False , converted
224+ # If it's a NativeOP but not registered, this is likely a bug
225+ raise TypeError (
226+ f"Attempted to assign a dpmodel object of type { type (value ).__name__ } "
227+ f"but no converter is registered. Please call register_dpmodel_mapping "
228+ f"for this type."
229+ )
218230
219231 return False , value
220232
0 commit comments