Skip to content

Commit 55e094e

Browse files
author
Han Wang
committed
raise error
1 parent ef84c6c commit 55e094e

2 files changed

Lines changed: 39 additions & 6 deletions

File tree

deepmd/pt_expt/common.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,19 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool,
216216

217217
# dpmodel object → pt_expt module
218218
if "_modules" in obj.__dict__:
219-
converted = try_convert_module(value)
220-
if converted is not None:
221-
return False, converted
222-
# Note: Some NativeOP objects (like EnvMat) don't need conversion and can
223-
# be used directly. If a NativeOP truly needs conversion but isn't registered,
224-
# it will fail at runtime when the object is actually used.
219+
# Try to convert dpmodel objects that aren't already torch.nn.Modules
220+
if not isinstance(value, torch.nn.Module):
221+
converted = try_convert_module(value)
222+
if converted is not None:
223+
return False, converted
224+
# If this is a NativeOP that should have been registered but wasn't, raise error
225+
if isinstance(value, NativeOP):
226+
raise TypeError(
227+
f"Attempted to assign a dpmodel object of type {type(value).__name__} "
228+
f"but no converter is registered. Please call register_dpmodel_mapping "
229+
f"for this type. If this object doesn't need conversion, register it "
230+
f"with an identity converter: lambda v: v"
231+
)
225232

226233
return False, value
227234

@@ -283,3 +290,18 @@ def to_torch_array(array: Any) -> torch.Tensor | None:
283290
if torch.is_tensor(array):
284291
return array.to(device=env.DEVICE)
285292
return torch.as_tensor(array, device=env.DEVICE)
293+
294+
295+
# Import utils to trigger dpmodel→pt_expt converter registrations
296+
# This must happen after the functions above are defined to avoid circular imports
297+
def _ensure_registrations() -> None:
298+
"""Import pt_expt.utils modules to register converters.
299+
300+
This function is called on module import to ensure all dpmodel→pt_expt
301+
converters are registered before any descriptors/fittings try to use them.
302+
"""
303+
# Import triggers registration of NetworkCollection, ExcludeMask, EnvMat
304+
from deepmd.pt_expt import utils # noqa: F401
305+
306+
307+
_ensure_registrations()

deepmd/pt_expt/utils/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
from deepmd.dpmodel.utils.env_mat import (
4+
EnvMat,
5+
)
6+
from deepmd.pt_expt.common import (
7+
register_dpmodel_mapping,
8+
)
9+
310
from .exclude_mask import (
411
AtomExcludeMask,
512
PairExcludeMask,
@@ -8,6 +15,10 @@
815
NetworkCollection,
916
)
1017

18+
# Register EnvMat with identity converter - it doesn't need wrapping
19+
# as it's a stateless utility class
20+
register_dpmodel_mapping(EnvMat, lambda v: v)
21+
1122
__all__ = [
1223
"AtomExcludeMask",
1324
"NetworkCollection",

0 commit comments

Comments
 (0)