Skip to content

Commit 0738f93

Browse files
committed
set device for DpLoaderSet
1 parent d93c948 commit 0738f93

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

deepmd/pt/utils/dataloader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,12 @@ def __len__(self) -> int:
195195

196196
def __getitem__(self, idx):
197197
# log.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
198-
try:
199-
batch = next(self.iters[idx])
200-
except StopIteration:
201-
self.iters[idx] = iter(self.dataloaders[idx])
202-
batch = next(self.iters[idx])
198+
with torch.device("cpu"):
199+
try:
200+
batch = next(self.iters[idx])
201+
except StopIteration:
202+
self.iters[idx] = iter(self.dataloaders[idx])
203+
batch = next(self.iters[idx])
203204
batch["sid"] = idx
204205
return batch
205206

0 commit comments

Comments
 (0)