Skip to content

Commit 30b762e

Browse files
caic99, Copilot, pre-commit-ci[bot]
authored
fix: remove the use of BufferedIterator (#4737)
This pull request refactors the data loading mechanism in the training module by replacing the `BufferedIterator` class with a simpler `cycle_iterator` function. I was able to reproduce this error on a system of Ni6Fe10 provided by @iProzd . The error log shows there is something strange when running garbage collection. I've tried the following approaches: - Not using GPU for training, but CPU: failed - Set `NUM_WORKERS`=0: worked - Manually run garbage collection `gc.collect()` when a DataLoader finishes its epoch: worked, but ~10% slower <details><summary>Error log</summary> <p> ``` Fatal Python error: Aborted Thread 0x00007f68b9289700 (most recent call first): File "/root/miniconda3/lib/python3.10/threading.py", line 320 in wait File "/root/miniconda3/lib/python3.10/multiprocessing/queues.py", line 231 in _feed File "/root/miniconda3/lib/python3.10/threading.py", line 953 in run File "/root/miniconda3/lib/python3.10/threading.py", line 1016 in _bootstrap_inner File "/root/miniconda3/lib/python3.10/threading.py", line 973 in _bootstrap Current thread 0x00007f6c595d7280 (most recent call first): Garbage-collecting File "/root/miniconda3/lib/python3.10/ast.py", line 99 in _convert File "/root/miniconda3/lib/python3.10/ast.py", line 110 in literal_eval File "/root/miniconda3/lib/python3.10/site-packages/numpy/lib/utils.py", line 1078 in safe_eval File "/root/miniconda3/lib/python3.10/site-packages/numpy/lib/format.py", line 623 in _read_array_header File "/root/miniconda3/lib/python3.10/site-packages/numpy/lib/format.py", line 784 in read_array File "/root/miniconda3/lib/python3.10/site-packages/numpy/lib/npyio.py", line 456 in load File "/aisi/cc/deepmd-kit/deepmd/utils/path.py", line 187 in load_numpy File "/aisi/cc/deepmd-kit/deepmd/utils/data.py", line 634 in _load_data File "/aisi/cc/deepmd-kit/deepmd/utils/data.py", line 526 in _load_set File "/aisi/cc/deepmd-kit/deepmd/utils/data.py", line 251 in get_item_torch File 
"/aisi/cc/deepmd-kit/deepmd/pt/utils/dataset.py", line 39 in __getitem__ File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52 in <listcomp> File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 52 in fetch File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 789 in _next_data File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 733 in __next__ File "/aisi/cc/deepmd-kit/deepmd/pt/utils/dataloader.py", line 204 in __getitem__ File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54 in fetch File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 349 in _worker_loop File "/root/miniconda3/lib/python3.10/multiprocessing/process.py", line 108 in run File "/root/miniconda3/lib/python3.10/multiprocessing/process.py", line 314 in _bootstrap File "/root/miniconda3/lib/python3.10/multiprocessing/popen_fork.py", line 71 in _launch File "/root/miniconda3/lib/python3.10/multiprocessing/popen_fork.py", line 19 in __init__ File "/root/miniconda3/lib/python3.10/multiprocessing/context.py", line 281 in _Popen File "/root/miniconda3/lib/python3.10/multiprocessing/context.py", line 224 in _Popen File "/root/miniconda3/lib/python3.10/multiprocessing/process.py", line 121 in start File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1171 in __init__ File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 424 in _get_iterator File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 493 in __iter__ File "/aisi/cc/deepmd-kit/deepmd/pt/train/training.py", line 1075 in get_data File "/aisi/cc/deepmd-kit/deepmd/pt/train/training.py", line 689 in step File "/aisi/cc/deepmd-kit/deepmd/pt/train/training.py", line 960 in run File 
"/aisi/cc/deepmd-kit/deepmd/pt/entrypoints/main.py", line 361 in train File "/aisi/cc/deepmd-kit/deepmd/pt/entrypoints/main.py", line 530 in main File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355 in wrapper File "/aisi/cc/deepmd-kit/deepmd/main.py", line 930 in main File "/root/miniconda3/bin/dp", line 8 in <module> Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, h5py._errors, h5py.defs, h5py._objects, h5py.h5, h5py.utils, h5py.h5t, h5py.h5s, h5py.h5ac, h5py.h5p, h5py.h5r, h5py._proxy, h5py._conv, h5py.h5z, h5py.h5a, h5py.h5d, h5py.h5ds, h5py.h5g, h5py.h5i, h5py.h5o, h5py.h5f, h5py.h5fd, h5py.h5pl, h5py.h5l, h5py._selector, yaml._yaml, scipy._lib._ccallback_c, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, 
scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.special._ellip_harm_2, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.optimize._group_columns, scipy._lib.messagestream, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._cython_nnls, scipy._lib._uarray._uarray, scipy.linalg._decomp_interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.spatial._ckdtree, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.optimize._direct, scipy.interpolate._dierckx, scipy.interpolate._ppoly, scipy.interpolate._interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.interpolate._bspl (total: 113) Traceback (most recent call last): File "/root/miniconda3/bin/dp", line 8, in <module> sys.exit(main()) File "/aisi/cc/deepmd-kit/deepmd/main.py", line 930, in main deepmd_main(args) File "/root/miniconda3/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper return f(*args, **kwargs) File "/aisi/cc/deepmd-kit/deepmd/pt/entrypoints/main.py", line 530, in main train( File "/aisi/cc/deepmd-kit/deepmd/pt/entrypoints/main.py", line 361, in train trainer.run() File "/aisi/cc/deepmd-kit/deepmd/pt/train/training.py", line 960, in run step(step_id) File "/aisi/cc/deepmd-kit/deepmd/pt/train/training.py", line 705, in step loss.backward() File "/root/miniconda3/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward torch.autograd.backward( File 
"/root/miniconda3/lib/python3.10/site-packages/torch/autograd/__init__.py", line 353, in backward _engine_run_backward( File "/root/miniconda3/lib/python3.10/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 73, in handler _error_if_any_worker_fails() RuntimeError: DataLoader worker (pid 434731) is killed by signal: Aborted. ``` </p> </details> The problem couples with pytorch tensor and python threads and pipes, which is hard to locate the root cause. I tested this PR on single-task training and multi-task training, and the training speed (in s/1000 steps) is almost the same: |data|before|after| |---|---|---| |omat|235.52|235.89| |multi-task pretraining|290.63|290.00| Fix #4586 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Enhanced data loading with an infinite cycling iterator for uninterrupted batch retrieval during training and validation. - Removed background prefetching and threading to simplify data loading utilities. - **Style** - Added a clarifying comment about shuffling behavior when distributed sampling is active. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <amoycaic@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 43e0288 commit 30b762e

File tree

4 files changed

+66
-130
lines changed

4 files changed

+66
-130
lines changed

deepmd/pt/train/training.py

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import functools
33
import logging
44
import time
5+
from collections.abc import (
6+
Iterable,
7+
)
58
from copy import (
69
deepcopy,
710
)
@@ -47,7 +50,6 @@
4750
dp_random,
4851
)
4952
from deepmd.pt.utils.dataloader import (
50-
BufferedIterator,
5153
get_sampler_from_params,
5254
)
5355
from deepmd.pt.utils.env import (
@@ -159,8 +161,24 @@ def get_opt_param(params):
159161
}
160162
return opt_type, opt_param
161163

164+
def cycle_iterator(iterable: Iterable):
165+
"""
166+
Produces an infinite iterator by repeatedly cycling through the given iterable.
167+
168+
Args:
169+
iterable (Iterable): The iterable to cycle through.
170+
171+
Yields
172+
------
173+
Any: The next item from the iterable, cycling back to the beginning when the end is reached.
174+
"""
175+
while True:
176+
with torch.device("cpu"):
177+
it = iter(iterable)
178+
yield from it
179+
162180
def get_data_loader(_training_data, _validation_data, _training_params):
163-
def get_dataloader_and_buffer(_data, _params):
181+
def get_dataloader_and_iter(_data, _params):
164182
_sampler = get_sampler_from_params(_data, _params)
165183
if _sampler is None:
166184
log.warning(
@@ -177,33 +195,32 @@ def get_dataloader_and_buffer(_data, _params):
177195
collate_fn=lambda batch: batch, # prevent extra conversion
178196
pin_memory=True,
179197
)
180-
with torch.device("cpu"):
181-
_data_buffered = BufferedIterator(iter(_dataloader))
182-
return _dataloader, _data_buffered
198+
_data_iter = cycle_iterator(_dataloader)
199+
return _dataloader, _data_iter
183200

184-
training_dataloader, training_data_buffered = get_dataloader_and_buffer(
201+
training_dataloader, training_data_iter = get_dataloader_and_iter(
185202
_training_data, _training_params["training_data"]
186203
)
187204

188205
if _validation_data is not None:
189206
(
190207
validation_dataloader,
191-
validation_data_buffered,
192-
) = get_dataloader_and_buffer(
208+
validation_data_iter,
209+
) = get_dataloader_and_iter(
193210
_validation_data, _training_params["validation_data"]
194211
)
195212
valid_numb_batch = _training_params["validation_data"].get(
196213
"numb_btch", 1
197214
)
198215
else:
199216
validation_dataloader = None
200-
validation_data_buffered = None
217+
validation_data_iter = None
201218
valid_numb_batch = 1
202219
return (
203220
training_dataloader,
204-
training_data_buffered,
221+
training_data_iter,
205222
validation_dataloader,
206-
validation_data_buffered,
223+
validation_data_iter,
207224
valid_numb_batch,
208225
)
209226

@@ -1064,48 +1081,15 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
10641081
checkpoint_files[0].unlink()
10651082

10661083
def get_data(self, is_train=True, task_key="Default"):
1067-
if not self.multi_task:
1068-
if is_train:
1069-
try:
1070-
batch_data = next(iter(self.training_data))
1071-
except StopIteration:
1072-
# Refresh the status of the dataloader to start from a new epoch
1073-
with torch.device("cpu"):
1074-
self.training_data = BufferedIterator(
1075-
iter(self.training_dataloader)
1076-
)
1077-
batch_data = next(iter(self.training_data))
1078-
else:
1079-
if self.validation_data is None:
1080-
return {}, {}, {}
1081-
try:
1082-
batch_data = next(iter(self.validation_data))
1083-
except StopIteration:
1084-
self.validation_data = BufferedIterator(
1085-
iter(self.validation_dataloader)
1086-
)
1087-
batch_data = next(iter(self.validation_data))
1084+
if is_train:
1085+
iterator = self.training_data
10881086
else:
1089-
if is_train:
1090-
try:
1091-
batch_data = next(iter(self.training_data[task_key]))
1092-
except StopIteration:
1093-
# Refresh the status of the dataloader to start from a new epoch
1094-
self.training_data[task_key] = BufferedIterator(
1095-
iter(self.training_dataloader[task_key])
1096-
)
1097-
batch_data = next(iter(self.training_data[task_key]))
1098-
else:
1099-
if self.validation_data[task_key] is None:
1100-
return {}, {}, {}
1101-
try:
1102-
batch_data = next(iter(self.validation_data[task_key]))
1103-
except StopIteration:
1104-
self.validation_data[task_key] = BufferedIterator(
1105-
iter(self.validation_dataloader[task_key])
1106-
)
1107-
batch_data = next(iter(self.validation_data[task_key]))
1108-
1087+
iterator = self.validation_data
1088+
if self.multi_task:
1089+
iterator = iterator[task_key]
1090+
if iterator is None:
1091+
return {}, {}, {}
1092+
batch_data = next(iterator)
11091093
for key in batch_data.keys():
11101094
if key == "sid" or key == "fid" or key == "box" or "find_" in key:
11111095
continue

deepmd/pt/utils/dataloader.py

Lines changed: 14 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
33
import os
4-
import time
54
from multiprocessing.dummy import (
65
Pool,
76
)
8-
from queue import (
9-
Queue,
10-
)
11-
from threading import (
12-
Thread,
13-
)
147

158
import h5py
169
import numpy as np
@@ -173,7 +166,9 @@ def construct_dataset(system):
173166
num_workers=0, # Should be 0 to avoid too many threads forked
174167
sampler=system_sampler,
175168
collate_fn=collate_batch,
176-
shuffle=(not (dist.is_available() and dist.is_initialized()))
169+
shuffle=(
170+
not (dist.is_available() and dist.is_initialized())
171+
) # distributed sampler will do the shuffling by default
177172
and shuffle,
178173
)
179174
self.dataloaders.append(system_dataloader)
@@ -200,11 +195,12 @@ def __len__(self) -> int:
200195

201196
def __getitem__(self, idx):
202197
# log.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
203-
try:
204-
batch = next(self.iters[idx])
205-
except StopIteration:
206-
self.iters[idx] = iter(self.dataloaders[idx])
207-
batch = next(self.iters[idx])
198+
with torch.device("cpu"):
199+
try:
200+
batch = next(self.iters[idx])
201+
except StopIteration:
202+
self.iters[idx] = iter(self.dataloaders[idx])
203+
batch = next(self.iters[idx])
208204
batch["sid"] = idx
209205
return batch
210206

@@ -235,54 +231,6 @@ def print_summary(
235231
)
236232

237233

238-
class BackgroundConsumer(Thread):
239-
def __init__(self, queue, source) -> None:
240-
super().__init__()
241-
self.daemon = True
242-
self._queue = queue
243-
self._source = source # Main DL iterator
244-
245-
def run(self) -> None:
246-
for item in self._source:
247-
self._queue.put(item) # Blocking if the queue is full
248-
249-
# Signal the consumer we are done; this should not happen for DataLoader
250-
self._queue.put(StopIteration())
251-
252-
253-
QUEUESIZE = 32
254-
255-
256-
class BufferedIterator:
257-
def __init__(self, iterable) -> None:
258-
self._queue = Queue(QUEUESIZE)
259-
self._iterable = iterable
260-
self._consumer = BackgroundConsumer(self._queue, self._iterable)
261-
self._consumer.start()
262-
self.last_warning_time = time.time()
263-
264-
def __iter__(self):
265-
return self
266-
267-
def __len__(self) -> int:
268-
return len(self._iterable)
269-
270-
def __next__(self):
271-
start_wait = time.time()
272-
item = self._queue.get()
273-
wait_time = time.time() - start_wait
274-
if (
275-
wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
276-
): # Even for Multi-Task training, each step usually takes < 1s
277-
log.warning(
278-
f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
279-
)
280-
self.last_warning_time = start_wait
281-
if isinstance(item, Exception):
282-
raise item
283-
return item
284-
285-
286234
def collate_batch(batch):
287235
example = batch[0]
288236
result = {}
@@ -320,7 +268,11 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False):
320268
# training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iterators
321269
len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1)
322270
with torch.device("cpu"):
323-
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
271+
sampler = WeightedRandomSampler(
272+
probs,
273+
len_sampler,
274+
replacement=True,
275+
)
324276
return sampler
325277

326278

source/tests/pt/model/test_saveload_dpa1.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
env,
2626
)
2727
from deepmd.pt.utils.dataloader import (
28-
BufferedIterator,
2928
DpLoaderSet,
3029
)
3130
from deepmd.pt.utils.stat import (
@@ -72,8 +71,13 @@ def setUp(self) -> None:
7271
drop_last=False,
7372
pin_memory=True,
7473
)
74+
75+
def cycle_iterator(iterable):
76+
while True:
77+
yield from iterable
78+
7579
with torch.device("cpu"):
76-
self.training_data = BufferedIterator(iter(self.training_dataloader))
80+
self.training_data = cycle_iterator(self.training_dataloader)
7781
self.loss = EnergyStdLoss(**self.config["loss"])
7882
self.cur_lr = 1
7983
self.task_key = "Default"
@@ -111,12 +115,8 @@ def create_wrapper(self, read: bool):
111115
return ModelWrapper(model, self.loss)
112116

113117
def get_data(self):
114-
try:
115-
batch_data = next(iter(self.training_data))
116-
except StopIteration:
117-
# Refresh the status of the dataloader to start from a new epoch
118-
self.training_data = BufferedIterator(iter(self.training_dataloader))
119-
batch_data = next(iter(self.training_data))
118+
with torch.device("cpu"):
119+
batch_data = next(self.training_data)
120120
input_dict = {}
121121
for item in ["coord", "atype", "box"]:
122122
if item in batch_data:

source/tests/pt/model/test_saveload_se_e2_a.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
env,
2626
)
2727
from deepmd.pt.utils.dataloader import (
28-
BufferedIterator,
2928
DpLoaderSet,
3029
)
3130
from deepmd.pt.utils.stat import (
@@ -72,8 +71,13 @@ def setUp(self) -> None:
7271
drop_last=False,
7372
pin_memory=True,
7473
)
74+
75+
def cycle_iterator(iterable):
76+
while True:
77+
yield from iterable
78+
7579
with torch.device("cpu"):
76-
self.training_data = BufferedIterator(iter(self.training_dataloader))
80+
self.training_data = cycle_iterator(self.training_dataloader)
7781
self.loss = EnergyStdLoss(**self.config["loss"])
7882
self.cur_lr = 1
7983
self.task_key = "Default"
@@ -105,12 +109,8 @@ def create_wrapper(self):
105109
return ModelWrapper(model, self.loss)
106110

107111
def get_data(self):
108-
try:
109-
batch_data = next(iter(self.training_data))
110-
except StopIteration:
111-
# Refresh the status of the dataloader to start from a new epoch
112-
self.training_data = BufferedIterator(iter(self.training_dataloader))
113-
batch_data = next(iter(self.training_data))
112+
with torch.device("cpu"):
113+
batch_data = next(self.training_data)
114114
input_dict = {}
115115
for item in ["coord", "atype", "box"]:
116116
if item in batch_data:

0 commit comments

Comments
 (0)