Skip to content

Commit b67222b

Browse files
authored
Merge branch 'devel' into D0516_dynamic_sel
2 parents a63513e + 8176173 commit b67222b

5 files changed

Lines changed: 67 additions & 131 deletions

File tree

deepmd/pt/model/network/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def forward(self, atype):
286286
type_embedding:
287287
288288
"""
289-
return self.embedding(atype.device)[atype]
289+
return torch.embedding(self.embedding(atype.device), atype)
290290

291291
def get_full_embedding(self, device: torch.device):
292292
"""

deepmd/pt/train/training.py

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import functools
33
import logging
44
import time
5+
from collections.abc import (
6+
Iterable,
7+
)
58
from copy import (
69
deepcopy,
710
)
@@ -47,7 +50,6 @@
4750
dp_random,
4851
)
4952
from deepmd.pt.utils.dataloader import (
50-
BufferedIterator,
5153
get_sampler_from_params,
5254
)
5355
from deepmd.pt.utils.env import (
@@ -159,8 +161,24 @@ def get_opt_param(params):
159161
}
160162
return opt_type, opt_param
161163

164+
def cycle_iterator(iterable: Iterable):
165+
"""
166+
Produces an infinite iterator by repeatedly cycling through the given iterable.
167+
168+
Args:
169+
iterable (Iterable): The iterable to cycle through.
170+
171+
Yields
172+
------
173+
Any: The next item from the iterable, cycling back to the beginning when the end is reached.
174+
"""
175+
while True:
176+
with torch.device("cpu"):
177+
it = iter(iterable)
178+
yield from it
179+
162180
def get_data_loader(_training_data, _validation_data, _training_params):
163-
def get_dataloader_and_buffer(_data, _params):
181+
def get_dataloader_and_iter(_data, _params):
164182
_sampler = get_sampler_from_params(_data, _params)
165183
if _sampler is None:
166184
log.warning(
@@ -177,33 +195,32 @@ def get_dataloader_and_buffer(_data, _params):
177195
collate_fn=lambda batch: batch, # prevent extra conversion
178196
pin_memory=True,
179197
)
180-
with torch.device("cpu"):
181-
_data_buffered = BufferedIterator(iter(_dataloader))
182-
return _dataloader, _data_buffered
198+
_data_iter = cycle_iterator(_dataloader)
199+
return _dataloader, _data_iter
183200

184-
training_dataloader, training_data_buffered = get_dataloader_and_buffer(
201+
training_dataloader, training_data_iter = get_dataloader_and_iter(
185202
_training_data, _training_params["training_data"]
186203
)
187204

188205
if _validation_data is not None:
189206
(
190207
validation_dataloader,
191-
validation_data_buffered,
192-
) = get_dataloader_and_buffer(
208+
validation_data_iter,
209+
) = get_dataloader_and_iter(
193210
_validation_data, _training_params["validation_data"]
194211
)
195212
valid_numb_batch = _training_params["validation_data"].get(
196213
"numb_btch", 1
197214
)
198215
else:
199216
validation_dataloader = None
200-
validation_data_buffered = None
217+
validation_data_iter = None
201218
valid_numb_batch = 1
202219
return (
203220
training_dataloader,
204-
training_data_buffered,
221+
training_data_iter,
205222
validation_dataloader,
206-
validation_data_buffered,
223+
validation_data_iter,
207224
valid_numb_batch,
208225
)
209226

@@ -1064,48 +1081,15 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
10641081
checkpoint_files[0].unlink()
10651082

10661083
def get_data(self, is_train=True, task_key="Default"):
1067-
if not self.multi_task:
1068-
if is_train:
1069-
try:
1070-
batch_data = next(iter(self.training_data))
1071-
except StopIteration:
1072-
# Refresh the status of the dataloader to start from a new epoch
1073-
with torch.device("cpu"):
1074-
self.training_data = BufferedIterator(
1075-
iter(self.training_dataloader)
1076-
)
1077-
batch_data = next(iter(self.training_data))
1078-
else:
1079-
if self.validation_data is None:
1080-
return {}, {}, {}
1081-
try:
1082-
batch_data = next(iter(self.validation_data))
1083-
except StopIteration:
1084-
self.validation_data = BufferedIterator(
1085-
iter(self.validation_dataloader)
1086-
)
1087-
batch_data = next(iter(self.validation_data))
1084+
if is_train:
1085+
iterator = self.training_data
10881086
else:
1089-
if is_train:
1090-
try:
1091-
batch_data = next(iter(self.training_data[task_key]))
1092-
except StopIteration:
1093-
# Refresh the status of the dataloader to start from a new epoch
1094-
self.training_data[task_key] = BufferedIterator(
1095-
iter(self.training_dataloader[task_key])
1096-
)
1097-
batch_data = next(iter(self.training_data[task_key]))
1098-
else:
1099-
if self.validation_data[task_key] is None:
1100-
return {}, {}, {}
1101-
try:
1102-
batch_data = next(iter(self.validation_data[task_key]))
1103-
except StopIteration:
1104-
self.validation_data[task_key] = BufferedIterator(
1105-
iter(self.validation_dataloader[task_key])
1106-
)
1107-
batch_data = next(iter(self.validation_data[task_key]))
1108-
1087+
iterator = self.validation_data
1088+
if self.multi_task:
1089+
iterator = iterator[task_key]
1090+
if iterator is None:
1091+
return {}, {}, {}
1092+
batch_data = next(iterator)
11091093
for key in batch_data.keys():
11101094
if key == "sid" or key == "fid" or key == "box" or "find_" in key:
11111095
continue

deepmd/pt/utils/dataloader.py

Lines changed: 14 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import logging
33
import os
4-
import time
54
from multiprocessing.dummy import (
65
Pool,
76
)
8-
from queue import (
9-
Queue,
10-
)
11-
from threading import (
12-
Thread,
13-
)
147

158
import h5py
169
import numpy as np
@@ -173,7 +166,9 @@ def construct_dataset(system):
173166
num_workers=0, # Should be 0 to avoid too many threads forked
174167
sampler=system_sampler,
175168
collate_fn=collate_batch,
176-
shuffle=(not (dist.is_available() and dist.is_initialized()))
169+
shuffle=(
170+
not (dist.is_available() and dist.is_initialized())
171+
) # distributed sampler will do the shuffling by default
177172
and shuffle,
178173
)
179174
self.dataloaders.append(system_dataloader)
@@ -200,11 +195,12 @@ def __len__(self) -> int:
200195

201196
def __getitem__(self, idx):
202197
# log.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
203-
try:
204-
batch = next(self.iters[idx])
205-
except StopIteration:
206-
self.iters[idx] = iter(self.dataloaders[idx])
207-
batch = next(self.iters[idx])
198+
with torch.device("cpu"):
199+
try:
200+
batch = next(self.iters[idx])
201+
except StopIteration:
202+
self.iters[idx] = iter(self.dataloaders[idx])
203+
batch = next(self.iters[idx])
208204
batch["sid"] = idx
209205
return batch
210206

@@ -235,54 +231,6 @@ def print_summary(
235231
)
236232

237233

238-
class BackgroundConsumer(Thread):
239-
def __init__(self, queue, source) -> None:
240-
super().__init__()
241-
self.daemon = True
242-
self._queue = queue
243-
self._source = source # Main DL iterator
244-
245-
def run(self) -> None:
246-
for item in self._source:
247-
self._queue.put(item) # Blocking if the queue is full
248-
249-
# Signal the consumer we are done; this should not happen for DataLoader
250-
self._queue.put(StopIteration())
251-
252-
253-
QUEUESIZE = 32
254-
255-
256-
class BufferedIterator:
257-
def __init__(self, iterable) -> None:
258-
self._queue = Queue(QUEUESIZE)
259-
self._iterable = iterable
260-
self._consumer = BackgroundConsumer(self._queue, self._iterable)
261-
self._consumer.start()
262-
self.last_warning_time = time.time()
263-
264-
def __iter__(self):
265-
return self
266-
267-
def __len__(self) -> int:
268-
return len(self._iterable)
269-
270-
def __next__(self):
271-
start_wait = time.time()
272-
item = self._queue.get()
273-
wait_time = time.time() - start_wait
274-
if (
275-
wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
276-
): # Even for Multi-Task training, each step usually takes < 1s
277-
log.warning(
278-
f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
279-
)
280-
self.last_warning_time = start_wait
281-
if isinstance(item, Exception):
282-
raise item
283-
return item
284-
285-
286234
def collate_batch(batch):
287235
example = batch[0]
288236
result = {}
@@ -320,7 +268,11 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False):
320268
# training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iterators
321269
len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1)
322270
with torch.device("cpu"):
323-
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
271+
sampler = WeightedRandomSampler(
272+
probs,
273+
len_sampler,
274+
replacement=True,
275+
)
324276
return sampler
325277

326278

source/tests/pt/model/test_saveload_dpa1.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
env,
2626
)
2727
from deepmd.pt.utils.dataloader import (
28-
BufferedIterator,
2928
DpLoaderSet,
3029
)
3130
from deepmd.pt.utils.stat import (
@@ -72,8 +71,13 @@ def setUp(self) -> None:
7271
drop_last=False,
7372
pin_memory=True,
7473
)
74+
75+
def cycle_iterator(iterable):
76+
while True:
77+
yield from iterable
78+
7579
with torch.device("cpu"):
76-
self.training_data = BufferedIterator(iter(self.training_dataloader))
80+
self.training_data = cycle_iterator(self.training_dataloader)
7781
self.loss = EnergyStdLoss(**self.config["loss"])
7882
self.cur_lr = 1
7983
self.task_key = "Default"
@@ -111,12 +115,8 @@ def create_wrapper(self, read: bool):
111115
return ModelWrapper(model, self.loss)
112116

113117
def get_data(self):
114-
try:
115-
batch_data = next(iter(self.training_data))
116-
except StopIteration:
117-
# Refresh the status of the dataloader to start from a new epoch
118-
self.training_data = BufferedIterator(iter(self.training_dataloader))
119-
batch_data = next(iter(self.training_data))
118+
with torch.device("cpu"):
119+
batch_data = next(self.training_data)
120120
input_dict = {}
121121
for item in ["coord", "atype", "box"]:
122122
if item in batch_data:

source/tests/pt/model/test_saveload_se_e2_a.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
env,
2626
)
2727
from deepmd.pt.utils.dataloader import (
28-
BufferedIterator,
2928
DpLoaderSet,
3029
)
3130
from deepmd.pt.utils.stat import (
@@ -72,8 +71,13 @@ def setUp(self) -> None:
7271
drop_last=False,
7372
pin_memory=True,
7473
)
74+
75+
def cycle_iterator(iterable):
76+
while True:
77+
yield from iterable
78+
7579
with torch.device("cpu"):
76-
self.training_data = BufferedIterator(iter(self.training_dataloader))
80+
self.training_data = cycle_iterator(self.training_dataloader)
7781
self.loss = EnergyStdLoss(**self.config["loss"])
7882
self.cur_lr = 1
7983
self.task_key = "Default"
@@ -105,12 +109,8 @@ def create_wrapper(self):
105109
return ModelWrapper(model, self.loss)
106110

107111
def get_data(self):
108-
try:
109-
batch_data = next(iter(self.training_data))
110-
except StopIteration:
111-
# Refresh the status of the dataloader to start from a new epoch
112-
self.training_data = BufferedIterator(iter(self.training_dataloader))
113-
batch_data = next(iter(self.training_data))
112+
with torch.device("cpu"):
113+
batch_data = next(self.training_data)
114114
input_dict = {}
115115
for item in ["coord", "atype", "box"]:
116116
if item in batch_data:

0 commit comments

Comments (0)