1313
1414from deepmd .dpmodel .array_api import (
1515 Array ,
16+ xp_take_first_n ,
1617)
1718from deepmd .dpmodel .common import (
1819 NativeOP ,
@@ -62,6 +63,45 @@ def __init__(
6263 self .rcond = rcond
6364 self .preset_out_bias = preset_out_bias
6465 self .data_stat_protect = data_stat_protect
66+ self ._observed_type : list [str ] | None = None
67+
68+ @property
69+ def observed_type (self ) -> list [str ] | None :
70+ """Get the observed element type list from data statistics."""
71+ return self ._observed_type
72+
def _collect_and_set_observed_type(
    self,
    sampled_func: Callable[[], list[dict]],
    stat_file_path: DPPath | None,
    preset_observed_type: list[str] | None,
) -> None:
    """Resolve the observed element types and store them on the instance.

    Sources are tried in this order and the first one that yields a
    result wins: the preset list, the stat file, then a fresh computation
    from sampled data.

    Parameters
    ----------
    sampled_func
        Zero-argument callable returning the sampled data frames. It is
        only invoked when neither the preset nor the stat file provides
        a result.
    stat_file_path
        The path to the statistics files (should already include the
        type_map suffix). Forwarded unchanged to the restore/save helpers.
    preset_observed_type
        User-specified observed types. When not ``None`` they are stored
        as-is and the stat file is neither read nor written.
    """
    from deepmd.dpmodel.utils.stat import (
        _restore_observed_type_from_file,
        _save_observed_type_to_file,
        collect_observed_types,
    )

    # 1) An explicit user choice overrides everything else.
    if preset_observed_type is not None:
        self._observed_type = preset_observed_type
        return

    # 2) Reuse a previously stored result when the helper finds one.
    restored = _restore_observed_type_from_file(stat_file_path)
    if restored is not None:
        self._observed_type = restored
        return

    # 3) Fall back to computing from the sampled frames, saving the
    #    result before it is assigned to the instance.
    computed = collect_observed_types(sampled_func(), self.type_map)
    _save_observed_type_to_file(stat_file_path, computed)
    self._observed_type = computed
65105
66106 def init_out_stat (self ) -> None :
67107 """Initialize the output bias."""
@@ -211,7 +251,7 @@ def forward_common_atomic(
211251 """
212252 xp = array_api_compat .array_namespace (extended_coord , extended_atype , nlist )
213253 _ , nloc , _ = nlist .shape
214- atype = extended_atype [:, : nloc ]
254+ atype = xp_take_first_n ( extended_atype , 1 , nloc )
215255 if self .pair_excl is not None :
216256 pair_mask = self .pair_excl .build_type_exclude_mask (nlist , extended_atype )
217257 # exclude neighbors in the nlist
@@ -229,7 +269,7 @@ def forward_common_atomic(
229269 ret_dict = self .apply_out_stat (ret_dict , atype )
230270
231271 # nf x nloc
232- atom_mask = ext_atom_mask [:, : nloc ]
272+ atom_mask = xp_take_first_n ( ext_atom_mask , 1 , nloc )
233273 if self .atom_excl is not None :
234274 atom_mask = xp .logical_and (
235275 atom_mask , self .atom_excl .build_type_exclude_mask (atype )
@@ -271,6 +311,29 @@ def get_compute_stats_distinguish_types(self) -> bool:
271311 """Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
272312 return True
273313
def compute_or_load_stat(
    self,
    sampled_func: Callable[[], list[dict]],
    stat_file_path: DPPath | None = None,
    compute_or_load_out_stat: bool = True,
    preset_observed_type: list[str] | None = None,
) -> None:
    """Compute or load the statistics parameters of the model.

    Examples of such statistics are the mean and standard deviation of
    descriptors, or the energy bias of the fitting net.

    This implementation does no work and always raises
    ``NotImplementedError``.
    NOTE(review): presumably concrete atomic models override it - confirm.

    Parameters
    ----------
    sampled_func
        The lazy sampled function to get data frames from different
        data systems.
    stat_file_path
        The path to the stat file.
    compute_or_load_out_stat : bool
        Whether to compute the output statistics.
        If False, only the input statistics are computed
        (e.g. mean and standard deviation of descriptors).
    preset_observed_type
        User-specified observed element types, or ``None``.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError
336+
274337 def compute_or_load_out_stat (
275338 self ,
276339 merged : Callable [[], list [dict ]] | list [dict ],
@@ -332,19 +395,19 @@ def wrapped_sampler() -> list[dict]:
332395 atom_exclude_types = self .atom_excl .get_exclude_types ()
333396 for sample in sampled :
334397 sample ["atom_exclude_types" ] = list (atom_exclude_types )
335- if (
336- "find_fparam" not in sampled [0 ]
337- and "fparam" not in sampled [0 ]
338- and self .has_default_fparam ()
339- ):
398+ # For systems where fparam is missing (find_fparam == 0),
399+ # fill with default fparam if available and mark as found.
400+ if self .has_default_fparam ():
340401 default_fparam = self .get_default_fparam ()
341402 if default_fparam is not None :
342403 default_fparam_np = np .array (default_fparam )
343404 for sample in sampled :
344- nframe = sample ["atype" ].shape [0 ]
345- sample ["fparam" ] = np .tile (
346- default_fparam_np .reshape (1 , - 1 ), (nframe , 1 )
347- )
405+ if "find_fparam" in sample and not sample ["find_fparam" ]:
406+ nframe = sample ["atype" ].shape [0 ]
407+ sample ["fparam" ] = np .tile (
408+ default_fparam_np .reshape (1 , - 1 ), (nframe , 1 )
409+ )
410+ sample ["find_fparam" ] = np .bool_ (True )
348411 return sampled
349412
350413 return wrapped_sampler
0 commit comments