Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,67 @@ def __getitem__(self, idx):
batch = next(self.iters[idx])
else:
try:
# For binding datasets, we need to ensure all subsystems load the same frame
# We'll get data from the first dataset (e.g. complex) first
first_key = self.part_key[0]
batch = {}
for kk in self.part_key:
batch[kk] = next(self.iters[kk][idx])

# Get the complex data first, which maintains its original randomness
# The random selection comes from the sampler configured in the DataLoader
first_key_data = next(self.iters[first_key][idx])
batch[first_key] = first_key_data

# Now, ensure other subsystems use the same frame indices
if 'fid' in first_key_data:
# Get the frame indices from the complex data
frame_indices = first_key_data['fid']

# For each other subsystem, get the same frames
for kk in self.part_key[1:]:
sub_batch = []
# Each batch may contain multiple frames
# We need to match each frame individually
for i, frame_idx in enumerate(frame_indices):
# Use the dataset's __getitem__ method to get the exact frame
# This ensures we always get the matching frame
frame = self.systems[kk][idx][frame_idx]
sub_batch.append(frame)

# Collate the selected frames into a batch
batch[kk] = collate_batch(sub_batch)
else:
# Fallback for cases where fid is not available
# Use the original method which may have synchronization issues
log.warning("Frame IDs not available, synchronization between subsystems may be imperfect")
for kk in self.part_key[1:]:
batch[kk] = next(self.iters[kk][idx])
except StopIteration:
# If complex iterator is exhausted, reset all iterators
for kk in self.part_key:
self.iters[kk][idx] = iter(self.dataloaders[kk][idx])

# Try again with reset iterators
first_key = self.part_key[0]
batch = {}
for kk in self.part_key:
batch[kk] = next(self.iters[kk][idx])
first_key_data = next(self.iters[first_key][idx])
batch[first_key] = first_key_data

# Same logic as above for getting matching frames
if 'fid' in first_key_data:
frame_indices = first_key_data['fid']

for kk in self.part_key[1:]:
sub_batch = []
for i, frame_idx in enumerate(frame_indices):
frame = self.systems[kk][idx][frame_idx]
sub_batch.append(frame)

batch[kk] = collate_batch(sub_batch)
else:
log.warning("Frame IDs not available, synchronization between subsystems may be imperfect")
for kk in self.part_key[1:]:
batch[kk] = next(self.iters[kk][idx])

batch["sid"] = idx
return batch

Expand Down
Loading