Skip to content

Commit 96230a3

Browse files
committed
bug fix
1 parent 256a92b commit 96230a3

3 files changed

Lines changed: 13 additions & 6 deletions

File tree

deepmd/pt/utils/stat.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
3+
import os
34
from collections import (
45
defaultdict,
56
)
@@ -60,7 +61,7 @@ def make_stat_input(
6061

6162
lst = []
6263
# I/O intensive, set a larger number of workers
63-
with ThreadPoolExecutor(max_workers=256) as executor:
64+
with ThreadPoolExecutor(min(128, (os.cpu_count() or 1) * 6)) as executor:
6465
lst = list(executor.map(_process_one_dataset, args_list))
6566
log.info("Finished packing data.")
6667
return lst
@@ -120,7 +121,7 @@ def _process_one_dataset(args: tuple[Any, int, int]) -> dict[str, Any]:
120121
if isinstance(sys_stat[key], np.float32):
121122
pass
122123
elif isinstance(sys_stat[key], list):
123-
if sys_stat[key][0] is None:
124+
if len(sys_stat[key]) == 0 or sys_stat[key][0] is None:
124125
sys_stat[key] = None
125126
else:
126127
sys_stat[key] = torch.cat(sys_stat[key], dim=0)

deepmd/utils/data.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,13 @@ def get_numb_set(self) -> int:
341341

342342
def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
343343
"""Get the number of batches in a set."""
344-
# Directly obtain the number of frames to avoid loading the entire dataset
345-
nframes = self._get_nframes(self.dirs[set_idx])
344+
set_name = self.dirs[set_idx]
345+
if isinstance(set_name, DPH5Path):
346+
data = self._load_set(set_name)
347+
nframes = data["coord"].shape[0]
348+
else:
349+
# Directly obtain the number of frames to avoid loading the entire dataset
350+
nframes = self._get_nframes(set_name)
346351
ret = nframes // batch_size
347352
if ret == 0:
348353
ret = 1
@@ -587,7 +592,7 @@ def _get_nframes(self, set_name: DPPath) -> int:
587592
if not isinstance(set_name, DPPath):
588593
set_name = DPPath(set_name)
589594
path = set_name / "coord.npy"
590-
# Read only the header to get shape without creating memmap
595+
# Read only the header to get shape
591596
with open(str(path), "rb") as f:
592597
version = np.lib.format.read_magic(f)
593598
if version[0] == 1:

deepmd/utils/env_mat_stat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
3+
import os
34
from abc import (
45
ABC,
56
abstractmethod,
@@ -160,7 +161,7 @@ def load_stats(self, path: DPPath) -> None:
160161
if not files_to_load:
161162
raise ValueError(f"No statistics files found in {path}.")
162163

163-
with ThreadPoolExecutor(max_workers=128) as executor:
164+
with ThreadPoolExecutor(min(64, (os.cpu_count() or 1) * 4)) as executor:
164165
results = executor.map(self._load_stat_file, files_to_load)
165166

166167
for name, stat_item in results:

0 commit comments

Comments
 (0)