Skip to content

Commit 87e9b9d

Browse files
author
Han Wang
committed
better type checking
1 parent de8f156 commit 87e9b9d

1 file changed

Lines changed: 19 additions & 7 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@
2525
import numpy as np
2626
import torch
2727

28+
from deepmd.dpmodel.common import (
29+
NativeOP,
30+
)
31+
2832
# ---------------------------------------------------------------------------
2933
# dpmodel → pt_expt converter registry
3034
# ---------------------------------------------------------------------------
31-
_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {}
35+
_DPMODEL_TO_PT_EXPT: dict[type[NativeOP], Callable[[NativeOP], torch.nn.Module]] = {}
3236
"""Registry mapping dpmodel classes to their pt_expt converter functions.
3337
3438
This registry is populated at module import time via `register_dpmodel_mapping`
@@ -43,7 +47,7 @@
4347

4448

4549
def register_dpmodel_mapping(
46-
dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module]
50+
dpmodel_cls: type[NativeOP], converter: Callable[[NativeOP], torch.nn.Module]
4751
) -> None:
4852
"""Register a converter that turns a dpmodel instance into a pt_expt Module.
4953
@@ -54,10 +58,10 @@ def register_dpmodel_mapping(
5458
5559
Parameters
5660
----------
57-
dpmodel_cls : type
61+
dpmodel_cls : type[NativeOP]
5862
The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP).
5963
This is the key used for lookup in dpmodel_setattr.
60-
converter : Callable[[Any], torch.nn.Module]
64+
converter : Callable[[NativeOP], torch.nn.Module]
6165
A callable that converts a dpmodel instance to a pt_expt module.
6266
Common patterns:
6367
- Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...)
@@ -212,9 +216,17 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool,
212216

213217
# dpmodel object → pt_expt module
214218
if "_modules" in obj.__dict__:
215-
converted = try_convert_module(value)
216-
if converted is not None:
217-
return False, converted
219+
# Check if this is a NativeOP that needs conversion
220+
if isinstance(value, NativeOP) and not isinstance(value, torch.nn.Module):
221+
converted = try_convert_module(value)
222+
if converted is not None:
223+
return False, converted
224+
# If it's a NativeOP but not registered, this is likely a bug
225+
raise TypeError(
226+
f"Attempted to assign a dpmodel object of type {type(value).__name__} "
227+
f"but no converter is registered. Please call register_dpmodel_mapping "
228+
f"for this type."
229+
)
218230

219231
return False, value
220232

0 commit comments

Comments
 (0)