Skip to content

Commit 0b3dfc8

Browse files
committed
Revert "perf: use multithread to accelerate stat computing and loading"
1 parent ac3dbbc commit 0b3dfc8

2 files changed

Lines changed: 49 additions & 112 deletions

File tree

deepmd/pt/utils/stat.py

Lines changed: 41 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
3-
import os
43
from collections import (
54
defaultdict,
65
)
7-
from concurrent.futures import (
8-
ThreadPoolExecutor,
9-
)
106
from typing import (
117
Any,
128
Callable,
@@ -43,7 +39,7 @@
4339
def make_stat_input(
4440
datasets: list[Any], dataloaders: list[Any], nbatches: int
4541
) -> dict[str, Any]:
46-
"""Pack data for statistics in parallel.
42+
"""Pack data for statistics.
4743
4844
Args:
4945
- dataset: A list of dataset to analyze.
@@ -53,83 +49,49 @@ def make_stat_input(
5349
-------
5450
- a list of dicts, each of which contains data from a system
5551
"""
56-
log.info(f"Packing data for statistics from {len(datasets)} systems")
57-
dataloader_lens = [len(dl) for dl in dataloaders]
58-
args_list = [
59-
(dataloaders[i], nbatches, dataloader_lens[i]) for i in range(len(datasets))
60-
]
61-
6252
lst = []
63-
# I/O intensive, set a larger number of workers
64-
with ThreadPoolExecutor(min(128, (os.cpu_count() or 1) * 6)) as executor:
65-
lst = list(executor.map(_process_one_dataset, args_list))
66-
log.info("Finished packing data.")
67-
return lst
68-
69-
70-
def _process_one_dataset(args: tuple[Any, int, int]) -> dict[str, Any]:
71-
"""
72-
Helper function to process a single dataset's dataloader for statistics.
73-
Designed to be called in parallel by a ThreadPoolExecutor.
74-
75-
Parameters
76-
----------
77-
args : tuple(Any, int, int)
78-
A tuple containing (dataloader, nbatches, dataloader_len)
79-
80-
Returns
81-
-------
82-
dict[str, Any]
83-
The processed sys_stat dictionary for one dataset.
84-
"""
85-
dataloader, nbatches, dataloader_len = args
86-
sys_stat = {}
87-
88-
with torch.device("cpu"):
89-
iterator = iter(dataloader)
90-
numb_batches = min(nbatches, dataloader_len)
91-
92-
for _ in range(numb_batches):
93-
try:
94-
stat_data = next(iterator)
95-
except StopIteration:
96-
iterator = iter(dataloader)
97-
stat_data = next(iterator)
98-
99-
if (
100-
"find_fparam" in stat_data
101-
and "fparam" in stat_data
102-
and stat_data["find_fparam"] == 0.0
103-
):
104-
# for model using default fparam
105-
stat_data.pop("fparam")
106-
stat_data.pop("find_fparam")
107-
108-
for dd in stat_data:
109-
if stat_data[dd] is None:
110-
sys_stat[dd] = None
111-
elif isinstance(stat_data[dd], torch.Tensor):
112-
if dd not in sys_stat:
113-
sys_stat[dd] = []
114-
sys_stat[dd].append(stat_data[dd])
115-
elif isinstance(stat_data[dd], np.float32):
116-
sys_stat[dd] = stat_data[dd]
117-
else:
118-
pass
119-
120-
for key in sys_stat:
121-
if isinstance(sys_stat[key], np.float32):
122-
pass
123-
elif isinstance(sys_stat[key], list):
124-
if len(sys_stat[key]) == 0 or sys_stat[key][0] is None:
53+
log.info(f"Packing data for statistics from {len(datasets)} systems")
54+
for i in range(len(datasets)):
55+
sys_stat = {}
56+
with torch.device("cpu"):
57+
iterator = iter(dataloaders[i])
58+
numb_batches = min(nbatches, len(dataloaders[i]))
59+
for _ in range(numb_batches):
60+
try:
61+
stat_data = next(iterator)
62+
except StopIteration:
63+
iterator = iter(dataloaders[i])
64+
stat_data = next(iterator)
65+
if (
66+
"find_fparam" in stat_data
67+
and "fparam" in stat_data
68+
and stat_data["find_fparam"] == 0.0
69+
):
70+
# for model using default fparam
71+
stat_data.pop("fparam")
72+
stat_data.pop("find_fparam")
73+
for dd in stat_data:
74+
if stat_data[dd] is None:
75+
sys_stat[dd] = None
76+
elif isinstance(stat_data[dd], torch.Tensor):
77+
if dd not in sys_stat:
78+
sys_stat[dd] = []
79+
sys_stat[dd].append(stat_data[dd])
80+
elif isinstance(stat_data[dd], np.float32):
81+
sys_stat[dd] = stat_data[dd]
82+
else:
83+
pass
84+
85+
for key in sys_stat:
86+
if isinstance(sys_stat[key], np.float32):
87+
pass
88+
elif sys_stat[key] is None or sys_stat[key][0] is None:
12589
sys_stat[key] = None
126-
else:
90+
elif isinstance(stat_data[dd], torch.Tensor):
12791
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
128-
elif sys_stat[key] is None:
129-
pass
130-
131-
dict_to_device(sys_stat)
132-
return sys_stat
92+
dict_to_device(sys_stat)
93+
lst.append(sys_stat)
94+
return lst
13395

13496

13597
def _restore_from_file(

deepmd/utils/env_mat_stat.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
3-
import os
43
from abc import (
54
ABC,
65
abstractmethod,
@@ -11,9 +10,6 @@
1110
from collections.abc import (
1211
Iterator,
1312
)
14-
from concurrent.futures import (
15-
ThreadPoolExecutor,
16-
)
1713
from typing import (
1814
Optional,
1915
)
@@ -146,7 +142,7 @@ def save_stats(self, path: DPPath) -> None:
146142
(path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum]))
147143

148144
def load_stats(self, path: DPPath) -> None:
149-
"""Load the statistics of the environment matrix in parallel.
145+
"""Load the statistics of the environment matrix.
150146
151147
Parameters
152148
----------
@@ -155,18 +151,13 @@ def load_stats(self, path: DPPath) -> None:
155151
"""
156152
if len(self.stats) > 0:
157153
raise ValueError("The statistics has already been computed.")
158-
159-
files_to_load = list(path.glob("*"))
160-
161-
if not files_to_load:
162-
raise ValueError(f"No statistics files found in {path}.")
163-
164-
with ThreadPoolExecutor(min(64, (os.cpu_count() or 1) * 4)) as executor:
165-
results = executor.map(self._load_stat_file, files_to_load)
166-
167-
for name, stat_item in results:
168-
if stat_item is not None:
169-
self.stats[name] = stat_item
154+
for kk in path.glob("*"):
155+
arr = kk.load_numpy()
156+
self.stats[kk.name] = StatItem(
157+
number=arr[0],
158+
sum=arr[1],
159+
squared_sum=arr[2],
160+
)
170161

171162
def load_or_compute_stats(
172163
self, data: list[dict[str, np.ndarray]], path: Optional[DPPath] = None
@@ -225,19 +216,3 @@ def get_std(
225216
kk: vv.compute_std(default=default, protection=protection)
226217
for kk, vv in self.stats.items()
227218
}
228-
229-
@staticmethod
230-
def _load_stat_file(file_path: DPPath) -> tuple[str, StatItem]:
231-
"""Helper function for parallel loading of stat files."""
232-
try:
233-
arr = file_path.load_numpy()
234-
if arr.shape == (3,):
235-
return file_path.name, StatItem(
236-
number=arr[0], sum=arr[1], squared_sum=arr[2]
237-
)
238-
else:
239-
log.warning(f"Skipping malformed stat file: {file_path.name}")
240-
return file_path.name, None
241-
except Exception as e:
242-
log.warning(f"Failed to load stat file {file_path.name}: {e}")
243-
return file_path.name, None

0 commit comments

Comments (0)