Skip to content

Commit 404b915

Browse files
committed
Update docstrings for compute_or_load_stat methods
Revised and clarified the docstrings for compute_or_load_stat in both BaseAtomicModel and PairTabAtomicModel to better describe the function parameters and behavior. Updated parameter names and descriptions for improved consistency and readability.
1 parent 2ef9444 commit 404b915

2 files changed

Lines changed: 22 additions & 22 deletions

File tree

deepmd/pt/model/atomic_model/base_atomic_model.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -366,19 +366,19 @@ def compute_or_load_stat(
366366
compute_or_load_out_stat: bool = True,
367367
) -> NoReturn:
368368
"""
369-
Compute the input and output statistics (e.g. energy bias) for the model from packed data.
369+
Compute or load the statistics parameters of the model,
370+
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
371+
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
372+
and saved in the `stat_file_path`(s).
373+
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
374+
and load the calculated statistics parameters.
370375
371376
Parameters
372377
----------
373-
merged : Union[Callable[[], list[dict]], list[dict]]
374-
- list[dict]: A list of data samples from various data systems.
375-
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
376-
originating from the `i`-th data system.
377-
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
378-
only when needed. Since the sampling process can be slow and memory-intensive,
379-
the lazy function helps by only sampling once.
380-
stat_file_path : Optional[DPPath]
381-
The path to the stat file.
378+
merged
379+
The lazy sampled function to get data frames from different data systems.
380+
stat_file_path
381+
The dictionary of paths to the statistics files.
382382
compute_or_load_out_stat : bool
383383
Whether to compute the output statistics.
384384
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).

deepmd/pt/model/atomic_model/pairtab_atomic_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,31 +224,31 @@ def deserialize(cls, data) -> "PairTabAtomicModel":
224224

225225
def compute_or_load_stat(
226226
self,
227-
merged: Union[Callable[[], list[dict]], list[dict]],
227+
sampled_func: Union[Callable[[], list[dict]], list[dict]],
228228
stat_file_path: Optional[DPPath] = None,
229229
compute_or_load_out_stat: bool = True,
230230
) -> None:
231231
"""
232-
Compute the input and output statistics (e.g. energy bias) for the model from packed data.
232+
Compute or load the statistics parameters of the model,
233+
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
234+
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
235+
and saved in the `stat_file_path`(s).
236+
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
237+
and load the calculated statistics parameters.
233238
234239
Parameters
235240
----------
236-
merged : Union[Callable[[], list[dict]], list[dict]]
237-
- list[dict]: A list of data samples from various data systems.
238-
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
239-
originating from the `i`-th data system.
240-
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
241-
only when needed. Since the sampling process can be slow and memory-intensive,
242-
the lazy function helps by only sampling once.
243-
stat_file_path : Optional[DPPath]
244-
The path to the stat file.
241+
sampled_func
242+
The lazy sampled function to get data frames from different data systems.
243+
stat_file_path
244+
The dictionary of paths to the statistics files.
245245
compute_or_load_out_stat : bool
246246
Whether to compute the output statistics.
247247
If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
248248
249249
"""
250250
if compute_or_load_out_stat:
251-
self.compute_or_load_out_stat(merged, stat_file_path)
251+
self.compute_or_load_out_stat(sampled_func, stat_file_path)
252252

253253
def forward_atomic(
254254
self,

0 commit comments

Comments
 (0)