Skip to content

Commit a9d28f7

Browse files
Solve conflict
1 parent 07483a7 commit a9d28f7

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

deepmd/pd/model/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,14 @@ def wrapped_sampler():
397397
return sampled
398398

399399
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
400-
self.compute_fitting_input_stat(wrapped_sampler)
400+
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
401401
if compute_or_load_out_stat:
402402
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
403403

404404
def compute_fitting_input_stat(
405405
self,
406406
sample_merged: Union[Callable[[], list[dict]], list[dict]],
407+
stat_file_path: Optional[DPPath] = None,
407408
) -> None:
408409
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
409410
@@ -416,9 +417,11 @@ def compute_fitting_input_stat(
416417
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
417418
only when needed. Since the sampling process can be slow and memory-intensive,
418419
the lazy function helps by only sampling once.
420+
stat_file_path : Optional[DPPath]
421+
The dictionary of paths to the statistics files.
419422
"""
420423
self.fitting_net.compute_input_stats(
421-
sample_merged, protection=self.data_stat_protect
424+
sample_merged, protection=self.data_stat_protect, stat_file_path=stat_file_path,
422425
)
423426

424427
def get_dim_fparam(self) -> int:

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,13 +338,14 @@ def wrapped_sampler() -> list[dict]:
338338
return sampled
339339

340340
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
341-
self.compute_fitting_input_stat(wrapped_sampler)
341+
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
342342
if compute_or_load_out_stat:
343343
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)
344344

345345
def compute_fitting_input_stat(
346346
self,
347347
sample_merged: Union[Callable[[], list[dict]], list[dict]],
348+
stat_file_path: Optional[DPPath] = None,
348349
) -> None:
349350
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
350351
@@ -357,9 +358,11 @@ def compute_fitting_input_stat(
357358
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
358359
only when needed. Since the sampling process can be slow and memory-intensive,
359360
the lazy function helps by only sampling once.
361+
stat_file_path : Optional[DPPath]
362+
The dictionary of paths to the statistics files.
360363
"""
361364
self.fitting_net.compute_input_stats(
362-
sample_merged, protection=self.data_stat_protect
365+
sample_merged, protection=self.data_stat_protect, stat_file_path=stat_file_path,
363366
)
364367

365368
def get_dim_fparam(self) -> int:

0 commit comments

Comments
 (0)