We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d93c948 commit 0738f93Copy full SHA for 0738f93
1 file changed
deepmd/pt/utils/dataloader.py
@@ -195,11 +195,12 @@ def __len__(self) -> int:
195
196
def __getitem__(self, idx):
197
# log.warning(str(torch.distributed.get_rank())+" idx: "+str(idx)+" index: "+str(self.index[idx]))
198
- try:
199
- batch = next(self.iters[idx])
200
- except StopIteration:
201
- self.iters[idx] = iter(self.dataloaders[idx])
202
+ with torch.device("cpu"):
+ try:
+ batch = next(self.iters[idx])
+ except StopIteration:
+ self.iters[idx] = iter(self.dataloaders[idx])
203
204
batch["sid"] = idx
205
return batch
206
0 commit comments