Skip to content

Commit 5d5c8e1

Browse files
committed
fix ut
1 parent 171c510 commit 5d5c8e1

2 files changed

Lines changed: 11 additions & 5 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,7 @@ def cycle_iterator(iterable: Iterable):
173173
Any: The next item from the iterable, cycling back to the beginning when the end is reached.
174174
"""
175175
while True:
176-
with torch.device("cpu"):
177-
it = iter(iterable)
178-
yield from it
176+
yield from iterable
179177

180178
def get_data_loader(_training_data, _validation_data, _training_params):
181179
def get_dataloader_and_iter(_data, _params):
@@ -194,7 +192,6 @@ def get_dataloader_and_iter(_data, _params):
194192
drop_last=False,
195193
collate_fn=lambda batch: batch, # prevent extra conversion
196194
pin_memory=True,
197-
pin_memory_device=str(object=DEVICE),
198195
)
199196
_data_iter = cycle_iterator(_dataloader)
200197
return _dataloader, _data_iter

deepmd/pt/utils/dataloader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,16 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False):
267267
# training_data.total_batch is the size of one epoch, you can increase it to avoid too many rebuilding of iterators
268268
len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1)
269269
with torch.device("cpu"):
270-
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
270+
sampler = WeightedRandomSampler(
271+
probs,
272+
len_sampler,
273+
replacement=True,
274+
generator=torch.Generator(),
275+
# If we are not setting the generator here, the random state will be initialized in the beginning of each new epoch.
276+
# This operation involves creating a new tensor for seeding on the default device,
277+
# while unit tests requires specifying the device explicitly.
278+
# https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/utils/data/sampler.py#L170
279+
)
271280
return sampler
272281

273282

0 commit comments

Comments
 (0)