Skip to content

Commit 7303412

Browse files
committed
refactor(pretrained): lazy load at backend boundary, eager deep_eval module
- keep lazy import in deepmd/backend/pretrained.py\n- keep deepmd/pretrained/deep_eval.py as regular (non-lazy) module\n- preserve deep eval delegations for descriptor/fitting-last-layer\n- simplify resolve_model_path and adjust tests for cached path behavior\n\nAuthored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.3-codex)
1 parent 27ac825 commit 7303412

3 files changed

Lines changed: 145 additions & 155 deletions

File tree

deepmd/backend/pretrained.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
from deepmd.backend.backend import (
1111
Backend,
1212
)
13-
from deepmd.pretrained.deep_eval import (
14-
get_pretrained_deep_eval_backend,
15-
)
1613

1714
if TYPE_CHECKING:
1815
from argparse import (
@@ -44,7 +41,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
4441

4542
@property
4643
def deep_eval(self) -> type["DeepEvalBackend"]:
47-
return get_pretrained_deep_eval_backend()
44+
from deepmd.pretrained.deep_eval import (
45+
PretrainedDeepEvalBackend,
46+
)
47+
48+
return PretrainedDeepEvalBackend
4849

4950
@property
5051
def neighbor_stat(self) -> type["NeighborStat"]:

deepmd/pretrained/deep_eval.py

Lines changed: 135 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
"""Backend helper for `*.pretrained` model aliases."""
2+
"""DeepEval adapter for `*.pretrained` model aliases."""
33

44
from __future__ import (
55
annotations,
66
)
77

8-
from functools import (
9-
lru_cache,
10-
)
118
from pathlib import (
129
Path,
1310
)
@@ -16,18 +13,16 @@
1613
Any,
1714
)
1815

16+
from deepmd.infer.deep_eval import (
17+
DeepEval,
18+
DeepEvalBackend,
19+
)
1920
from deepmd.pretrained.download import (
2021
resolve_model_path,
2122
)
22-
2323
if TYPE_CHECKING:
2424
import numpy as np
2525

26-
from deepmd.infer.deep_eval import (
27-
DeepEval,
28-
DeepEvalBackend,
29-
)
30-
3126

3227
def parse_pretrained_alias(model_file: str) -> str:
3328
"""Extract model name from ``*.pretrained`` alias string."""
@@ -43,143 +38,133 @@ def parse_pretrained_alias(model_file: str) -> str:
4338
return model_name
4439

4540

46-
@lru_cache(maxsize=1)
47-
def get_pretrained_deep_eval_backend() -> type[DeepEvalBackend]:
48-
"""Build and cache the concrete DeepEval adapter lazily."""
49-
# Avoid circular import when deepmd backend entrypoints are loading.
50-
from deepmd.infer.deep_eval import (
51-
DeepEvalBackend,
52-
)
53-
54-
class PretrainedDeepEvalBackend(DeepEvalBackend):
55-
"""Resolve alias and delegate to backend selected by resolved model path."""
56-
57-
def __init__(
58-
self,
59-
model_file: str,
60-
output_def: object,
61-
*args: object,
62-
auto_batch_size: object = True,
63-
neighbor_list: object | None = None,
64-
**kwargs: object,
65-
) -> None:
66-
model_name = parse_pretrained_alias(model_file)
67-
resolved = str(resolve_model_path(model_name))
68-
69-
# DeepEvalBackend.__new__ dispatches by resolved suffix (.pt/.pb/.dp...)
70-
self._backend = DeepEvalBackend(
71-
resolved,
72-
output_def,
73-
*args,
74-
auto_batch_size=auto_batch_size,
75-
neighbor_list=neighbor_list,
76-
**kwargs,
77-
)
78-
79-
def eval(
80-
self,
81-
coords: np.ndarray,
82-
cells: np.ndarray | None,
83-
atom_types: np.ndarray,
84-
atomic: bool = False,
85-
fparam: np.ndarray | None = None,
86-
aparam: np.ndarray | None = None,
87-
**kwargs: Any,
88-
) -> dict[str, np.ndarray]:
89-
return self._backend.eval(
90-
coords,
91-
cells,
92-
atom_types,
93-
atomic,
94-
fparam=fparam,
95-
aparam=aparam,
96-
**kwargs,
97-
)
98-
99-
def eval_descriptor(
100-
self,
101-
coords: np.ndarray,
102-
cells: np.ndarray | None,
103-
atom_types: np.ndarray,
104-
fparam: np.ndarray | None = None,
105-
aparam: np.ndarray | None = None,
106-
efield: np.ndarray | None = None,
107-
mixed_type: bool = False,
108-
**kwargs: Any,
109-
) -> np.ndarray:
110-
return self._backend.eval_descriptor(
111-
coords,
112-
cells,
113-
atom_types,
114-
fparam=fparam,
115-
aparam=aparam,
116-
efield=efield,
117-
mixed_type=mixed_type,
118-
**kwargs,
119-
)
120-
121-
def eval_fitting_last_layer(
122-
self,
123-
coords: np.ndarray,
124-
cells: np.ndarray | None,
125-
atom_types: np.ndarray,
126-
fparam: np.ndarray | None = None,
127-
aparam: np.ndarray | None = None,
128-
**kwargs: Any,
129-
) -> np.ndarray:
130-
return self._backend.eval_fitting_last_layer(
131-
coords,
132-
cells,
133-
atom_types,
134-
fparam=fparam,
135-
aparam=aparam,
136-
**kwargs,
137-
)
138-
139-
def get_rcut(self) -> float:
140-
return self._backend.get_rcut()
141-
142-
def get_ntypes(self) -> int:
143-
return self._backend.get_ntypes()
144-
145-
def get_type_map(self) -> list[str]:
146-
return self._backend.get_type_map()
147-
148-
def get_dim_fparam(self) -> int:
149-
return self._backend.get_dim_fparam()
150-
151-
def has_default_fparam(self) -> bool:
152-
return self._backend.has_default_fparam()
153-
154-
def get_dim_aparam(self) -> int:
155-
return self._backend.get_dim_aparam()
156-
157-
@property
158-
def model_type(self) -> type[DeepEval]:
159-
return self._backend.model_type
160-
161-
def get_sel_type(self) -> list[int]:
162-
return self._backend.get_sel_type()
163-
164-
def get_numb_dos(self) -> int:
165-
return self._backend.get_numb_dos()
166-
167-
def get_has_efield(self) -> bool:
168-
return self._backend.get_has_efield()
169-
170-
def get_has_spin(self) -> bool:
171-
return self._backend.get_has_spin()
172-
173-
def get_has_hessian(self) -> bool:
174-
return self._backend.get_has_hessian()
175-
176-
def get_var_name(self) -> str:
177-
return self._backend.get_var_name()
178-
179-
def get_ntypes_spin(self) -> int:
180-
return self._backend.get_ntypes_spin()
181-
182-
def get_model(self) -> Any:
183-
return self._backend.get_model()
184-
185-
return PretrainedDeepEvalBackend
41+
class PretrainedDeepEvalBackend(DeepEvalBackend):
42+
"""Resolve alias and delegate to backend selected by resolved model path."""
43+
44+
def __init__(
45+
self,
46+
model_file: str,
47+
output_def: object,
48+
*args: object,
49+
auto_batch_size: object = True,
50+
neighbor_list: object | None = None,
51+
**kwargs: object,
52+
) -> None:
53+
model_name = parse_pretrained_alias(model_file)
54+
resolved = str(resolve_model_path(model_name))
55+
56+
# DeepEvalBackend.__new__ dispatches by resolved suffix (.pt/.pb/.dp...)
57+
self._backend = DeepEvalBackend(
58+
resolved,
59+
output_def,
60+
*args,
61+
auto_batch_size=auto_batch_size,
62+
neighbor_list=neighbor_list,
63+
**kwargs,
64+
)
65+
66+
def eval(
67+
self,
68+
coords: np.ndarray,
69+
cells: np.ndarray | None,
70+
atom_types: np.ndarray,
71+
atomic: bool = False,
72+
fparam: np.ndarray | None = None,
73+
aparam: np.ndarray | None = None,
74+
**kwargs: Any,
75+
) -> dict[str, np.ndarray]:
76+
return self._backend.eval(
77+
coords,
78+
cells,
79+
atom_types,
80+
atomic,
81+
fparam=fparam,
82+
aparam=aparam,
83+
**kwargs,
84+
)
85+
86+
def eval_descriptor(
87+
self,
88+
coords: np.ndarray,
89+
cells: np.ndarray | None,
90+
atom_types: np.ndarray,
91+
fparam: np.ndarray | None = None,
92+
aparam: np.ndarray | None = None,
93+
efield: np.ndarray | None = None,
94+
mixed_type: bool = False,
95+
**kwargs: Any,
96+
) -> np.ndarray:
97+
return self._backend.eval_descriptor(
98+
coords,
99+
cells,
100+
atom_types,
101+
fparam=fparam,
102+
aparam=aparam,
103+
efield=efield,
104+
mixed_type=mixed_type,
105+
**kwargs,
106+
)
107+
108+
def eval_fitting_last_layer(
109+
self,
110+
coords: np.ndarray,
111+
cells: np.ndarray | None,
112+
atom_types: np.ndarray,
113+
fparam: np.ndarray | None = None,
114+
aparam: np.ndarray | None = None,
115+
**kwargs: Any,
116+
) -> np.ndarray:
117+
return self._backend.eval_fitting_last_layer(
118+
coords,
119+
cells,
120+
atom_types,
121+
fparam=fparam,
122+
aparam=aparam,
123+
**kwargs,
124+
)
125+
126+
def get_rcut(self) -> float:
127+
return self._backend.get_rcut()
128+
129+
def get_ntypes(self) -> int:
130+
return self._backend.get_ntypes()
131+
132+
def get_type_map(self) -> list[str]:
133+
return self._backend.get_type_map()
134+
135+
def get_dim_fparam(self) -> int:
136+
return self._backend.get_dim_fparam()
137+
138+
def has_default_fparam(self) -> bool:
139+
return self._backend.has_default_fparam()
140+
141+
def get_dim_aparam(self) -> int:
142+
return self._backend.get_dim_aparam()
143+
144+
@property
145+
def model_type(self) -> type[DeepEval]:
146+
return self._backend.model_type
147+
148+
def get_sel_type(self) -> list[int]:
149+
return self._backend.get_sel_type()
150+
151+
def get_numb_dos(self) -> int:
152+
return self._backend.get_numb_dos()
153+
154+
def get_has_efield(self) -> bool:
155+
return self._backend.get_has_efield()
156+
157+
def get_has_spin(self) -> bool:
158+
return self._backend.get_has_spin()
159+
160+
def get_has_hessian(self) -> bool:
161+
return self._backend.get_has_hessian()
162+
163+
def get_var_name(self) -> str:
164+
return self._backend.get_var_name()
165+
166+
def get_ntypes_spin(self) -> int:
167+
return self._backend.get_ntypes_spin()
168+
169+
def get_model(self) -> Any:
170+
return self._backend.get_model()

source/tests/common/test_pretrained_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,8 @@ def test_parse_pretrained_alias_invalid(self) -> None:
3838
parse_pretrained_alias("DPA-3.2-5M.pt")
3939

4040
def test_deep_eval_property(self) -> None:
41-
self.assertIsNotNone(PretrainedBackend().deep_eval)
41+
from deepmd.pretrained.deep_eval import (
42+
PretrainedDeepEvalBackend,
43+
)
44+
45+
self.assertIs(PretrainedBackend().deep_eval, PretrainedDeepEvalBackend)

0 commit comments

Comments
 (0)