11# SPDX-License-Identifier: LGPL-3.0-or-later
2- """Backend helper for `*.pretrained` model aliases."""
2+ """DeepEval adapter for `*.pretrained` model aliases."""
33
44from __future__ import (
55 annotations ,
66)
77
8- from functools import (
9- lru_cache ,
10- )
118from pathlib import (
129 Path ,
1310)
1613 Any ,
1714)
1815
16+ from deepmd .infer .deep_eval import (
17+ DeepEval ,
18+ DeepEvalBackend ,
19+ )
1920from deepmd .pretrained .download import (
2021 resolve_model_path ,
2122)
22-
2323if TYPE_CHECKING :
2424 import numpy as np
2525
26- from deepmd .infer .deep_eval import (
27- DeepEval ,
28- DeepEvalBackend ,
29- )
30-
3126
3227def parse_pretrained_alias (model_file : str ) -> str :
3328 """Extract model name from ``*.pretrained`` alias string."""
@@ -43,143 +38,133 @@ def parse_pretrained_alias(model_file: str) -> str:
4338 return model_name
4439
4540
46- @lru_cache (maxsize = 1 )
47- def get_pretrained_deep_eval_backend () -> type [DeepEvalBackend ]:
48- """Build and cache the concrete DeepEval adapter lazily."""
49- # Avoid circular import when deepmd backend entrypoints are loading.
50- from deepmd .infer .deep_eval import (
51- DeepEvalBackend ,
52- )
53-
54- class PretrainedDeepEvalBackend (DeepEvalBackend ):
55- """Resolve alias and delegate to backend selected by resolved model path."""
56-
57- def __init__ (
58- self ,
59- model_file : str ,
60- output_def : object ,
61- * args : object ,
62- auto_batch_size : object = True ,
63- neighbor_list : object | None = None ,
64- ** kwargs : object ,
65- ) -> None :
66- model_name = parse_pretrained_alias (model_file )
67- resolved = str (resolve_model_path (model_name ))
68-
69- # DeepEvalBackend.__new__ dispatches by resolved suffix (.pt/.pb/.dp...)
70- self ._backend = DeepEvalBackend (
71- resolved ,
72- output_def ,
73- * args ,
74- auto_batch_size = auto_batch_size ,
75- neighbor_list = neighbor_list ,
76- ** kwargs ,
77- )
78-
79- def eval (
80- self ,
81- coords : np .ndarray ,
82- cells : np .ndarray | None ,
83- atom_types : np .ndarray ,
84- atomic : bool = False ,
85- fparam : np .ndarray | None = None ,
86- aparam : np .ndarray | None = None ,
87- ** kwargs : Any ,
88- ) -> dict [str , np .ndarray ]:
89- return self ._backend .eval (
90- coords ,
91- cells ,
92- atom_types ,
93- atomic ,
94- fparam = fparam ,
95- aparam = aparam ,
96- ** kwargs ,
97- )
98-
99- def eval_descriptor (
100- self ,
101- coords : np .ndarray ,
102- cells : np .ndarray | None ,
103- atom_types : np .ndarray ,
104- fparam : np .ndarray | None = None ,
105- aparam : np .ndarray | None = None ,
106- efield : np .ndarray | None = None ,
107- mixed_type : bool = False ,
108- ** kwargs : Any ,
109- ) -> np .ndarray :
110- return self ._backend .eval_descriptor (
111- coords ,
112- cells ,
113- atom_types ,
114- fparam = fparam ,
115- aparam = aparam ,
116- efield = efield ,
117- mixed_type = mixed_type ,
118- ** kwargs ,
119- )
120-
121- def eval_fitting_last_layer (
122- self ,
123- coords : np .ndarray ,
124- cells : np .ndarray | None ,
125- atom_types : np .ndarray ,
126- fparam : np .ndarray | None = None ,
127- aparam : np .ndarray | None = None ,
128- ** kwargs : Any ,
129- ) -> np .ndarray :
130- return self ._backend .eval_fitting_last_layer (
131- coords ,
132- cells ,
133- atom_types ,
134- fparam = fparam ,
135- aparam = aparam ,
136- ** kwargs ,
137- )
138-
139- def get_rcut (self ) -> float :
140- return self ._backend .get_rcut ()
141-
142- def get_ntypes (self ) -> int :
143- return self ._backend .get_ntypes ()
144-
145- def get_type_map (self ) -> list [str ]:
146- return self ._backend .get_type_map ()
147-
148- def get_dim_fparam (self ) -> int :
149- return self ._backend .get_dim_fparam ()
150-
151- def has_default_fparam (self ) -> bool :
152- return self ._backend .has_default_fparam ()
153-
154- def get_dim_aparam (self ) -> int :
155- return self ._backend .get_dim_aparam ()
156-
157- @property
158- def model_type (self ) -> type [DeepEval ]:
159- return self ._backend .model_type
160-
161- def get_sel_type (self ) -> list [int ]:
162- return self ._backend .get_sel_type ()
163-
164- def get_numb_dos (self ) -> int :
165- return self ._backend .get_numb_dos ()
166-
167- def get_has_efield (self ) -> bool :
168- return self ._backend .get_has_efield ()
169-
170- def get_has_spin (self ) -> bool :
171- return self ._backend .get_has_spin ()
172-
173- def get_has_hessian (self ) -> bool :
174- return self ._backend .get_has_hessian ()
175-
176- def get_var_name (self ) -> str :
177- return self ._backend .get_var_name ()
178-
179- def get_ntypes_spin (self ) -> int :
180- return self ._backend .get_ntypes_spin ()
181-
182- def get_model (self ) -> Any :
183- return self ._backend .get_model ()
184-
185- return PretrainedDeepEvalBackend
41+ class PretrainedDeepEvalBackend (DeepEvalBackend ):
42+ """Resolve alias and delegate to backend selected by resolved model path."""
43+
44+ def __init__ (
45+ self ,
46+ model_file : str ,
47+ output_def : object ,
48+ * args : object ,
49+ auto_batch_size : object = True ,
50+ neighbor_list : object | None = None ,
51+ ** kwargs : object ,
52+ ) -> None :
53+ model_name = parse_pretrained_alias (model_file )
54+ resolved = str (resolve_model_path (model_name ))
55+
56+ # DeepEvalBackend.__new__ dispatches by resolved suffix (.pt/.pb/.dp...)
57+ self ._backend = DeepEvalBackend (
58+ resolved ,
59+ output_def ,
60+ * args ,
61+ auto_batch_size = auto_batch_size ,
62+ neighbor_list = neighbor_list ,
63+ ** kwargs ,
64+ )
65+
66+ def eval (
67+ self ,
68+ coords : np .ndarray ,
69+ cells : np .ndarray | None ,
70+ atom_types : np .ndarray ,
71+ atomic : bool = False ,
72+ fparam : np .ndarray | None = None ,
73+ aparam : np .ndarray | None = None ,
74+ ** kwargs : Any ,
75+ ) -> dict [str , np .ndarray ]:
76+ return self ._backend .eval (
77+ coords ,
78+ cells ,
79+ atom_types ,
80+ atomic ,
81+ fparam = fparam ,
82+ aparam = aparam ,
83+ ** kwargs ,
84+ )
85+
86+ def eval_descriptor (
87+ self ,
88+ coords : np .ndarray ,
89+ cells : np .ndarray | None ,
90+ atom_types : np .ndarray ,
91+ fparam : np .ndarray | None = None ,
92+ aparam : np .ndarray | None = None ,
93+ efield : np .ndarray | None = None ,
94+ mixed_type : bool = False ,
95+ ** kwargs : Any ,
96+ ) -> np .ndarray :
97+ return self ._backend .eval_descriptor (
98+ coords ,
99+ cells ,
100+ atom_types ,
101+ fparam = fparam ,
102+ aparam = aparam ,
103+ efield = efield ,
104+ mixed_type = mixed_type ,
105+ ** kwargs ,
106+ )
107+
108+ def eval_fitting_last_layer (
109+ self ,
110+ coords : np .ndarray ,
111+ cells : np .ndarray | None ,
112+ atom_types : np .ndarray ,
113+ fparam : np .ndarray | None = None ,
114+ aparam : np .ndarray | None = None ,
115+ ** kwargs : Any ,
116+ ) -> np .ndarray :
117+ return self ._backend .eval_fitting_last_layer (
118+ coords ,
119+ cells ,
120+ atom_types ,
121+ fparam = fparam ,
122+ aparam = aparam ,
123+ ** kwargs ,
124+ )
125+
126+ def get_rcut (self ) -> float :
127+ return self ._backend .get_rcut ()
128+
129+ def get_ntypes (self ) -> int :
130+ return self ._backend .get_ntypes ()
131+
132+ def get_type_map (self ) -> list [str ]:
133+ return self ._backend .get_type_map ()
134+
135+ def get_dim_fparam (self ) -> int :
136+ return self ._backend .get_dim_fparam ()
137+
138+ def has_default_fparam (self ) -> bool :
139+ return self ._backend .has_default_fparam ()
140+
141+ def get_dim_aparam (self ) -> int :
142+ return self ._backend .get_dim_aparam ()
143+
144+ @property
145+ def model_type (self ) -> type [DeepEval ]:
146+ return self ._backend .model_type
147+
148+ def get_sel_type (self ) -> list [int ]:
149+ return self ._backend .get_sel_type ()
150+
151+ def get_numb_dos (self ) -> int :
152+ return self ._backend .get_numb_dos ()
153+
154+ def get_has_efield (self ) -> bool :
155+ return self ._backend .get_has_efield ()
156+
157+ def get_has_spin (self ) -> bool :
158+ return self ._backend .get_has_spin ()
159+
160+ def get_has_hessian (self ) -> bool :
161+ return self ._backend .get_has_hessian ()
162+
163+ def get_var_name (self ) -> str :
164+ return self ._backend .get_var_name ()
165+
166+ def get_ntypes_spin (self ) -> int :
167+ return self ._backend .get_ntypes_spin ()
168+
169+ def get_model (self ) -> Any :
170+ return self ._backend .get_model ()
0 commit comments