Skip to content

Commit 2bb190f

Browse files
committed
fix docstr
1 parent 7770428 commit 2bb190f

4 files changed

Lines changed: 21 additions & 9 deletions

File tree

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,10 @@ def compute_or_load_stat(
363363
self,
364364
merged: Union[Callable[[], list[dict]], list[dict]],
365365
stat_file_path: Optional[DPPath] = None,
366-
compute_out_stat: bool = True,
366+
compute_or_load_out_stat: bool = True,
367367
) -> NoReturn:
368368
"""
369-
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
369+
Compute the input and output statistics (e.g. energy bias) for the model from packed data.
370370
371371
Parameters
372372
----------
@@ -379,6 +379,9 @@ def compute_or_load_stat(
379379
the lazy function helps by only sampling once.
380380
stat_file_path : Optional[DPPath]
381381
The path to the stat file.
382+
compute_or_load_out_stat : bool
383+
Whether to compute the output statistics.
384+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
382385
383386
"""
384387
raise NotImplementedError

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def compute_or_load_stat(
285285
self,
286286
sampled_func,
287287
stat_file_path: Optional[DPPath] = None,
288-
compute_out_stat: bool = True,
288+
compute_or_load_out_stat: bool = True,
289289
) -> None:
290290
"""
291291
Compute or load the statistics parameters of the model,
@@ -301,6 +301,9 @@ def compute_or_load_stat(
301301
The lazy sampled function to get data frames from different data systems.
302302
stat_file_path
303303
The dictionary of paths to the statistics files.
304+
compute_or_load_out_stat : bool
305+
Whether to compute the output statistics.
306+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
304307
"""
305308
if stat_file_path is not None and self.type_map is not None:
306309
# descriptors and fitting net with different type_map
@@ -324,7 +327,7 @@ def wrapped_sampler():
324327
self.fitting_net.compute_input_stats(
325328
wrapped_sampler, protection=self.data_stat_protect
326329
)
327-
if compute_out_stat:
330+
if compute_or_load_out_stat:
328331
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
329332

330333
def get_dim_fparam(self) -> int:

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def compute_or_load_stat(
472472
self,
473473
sampled_func,
474474
stat_file_path: Optional[DPPath] = None,
475-
compute_out_stat: bool = True,
475+
compute_or_load_out_stat: bool = True,
476476
) -> None:
477477
"""
478478
Compute or load the statistics parameters of the model,
@@ -488,10 +488,13 @@ def compute_or_load_stat(
488488
The lazy sampled function to get data frames from different data systems.
489489
stat_file_path
490490
The dictionary of paths to the statistics files.
491+
compute_or_load_out_stat : bool
492+
Whether to compute the output statistics.
493+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
491494
"""
492495
for md in self.models:
493496
md.compute_or_load_stat(
494-
sampled_func, stat_file_path, compute_out_stat=False
497+
sampled_func, stat_file_path, compute_or_load_out_stat=False
495498
)
496499

497500
if stat_file_path is not None and self.type_map is not None:

deepmd/pt/model/atomic_model/pairtab_atomic_model.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,10 +226,10 @@ def compute_or_load_stat(
226226
self,
227227
merged: Union[Callable[[], list[dict]], list[dict]],
228228
stat_file_path: Optional[DPPath] = None,
229-
compute_out_stat: bool = True,
229+
compute_or_load_out_stat: bool = True,
230230
) -> None:
231231
"""
232-
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
232+
Compute the input and output statistics (e.g. energy bias) for the model from packed data.
233233
234234
Parameters
235235
----------
@@ -242,9 +242,12 @@ def compute_or_load_stat(
242242
the lazy function helps by only sampling once.
243243
stat_file_path : Optional[DPPath]
244244
The path to the stat file.
245+
compute_or_load_out_stat : bool
246+
Whether to compute the output statistics.
247+
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
245248
246249
"""
247-
if compute_out_stat:
250+
if compute_or_load_out_stat:
248251
self.compute_or_load_out_stat(merged, stat_file_path)
249252

250253
def forward_atomic(

0 commit comments

Comments
 (0)