diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index ec1aa68128..1f09f41320 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -241,15 +241,67 @@ def __getitem__(self, idx): batch = next(self.iters[idx]) else: try: + # For binding datasets, we need to ensure all subsystems load the same frame + # We'll get data from the first dataset (e.g. complex) first + first_key = self.part_key[0] batch = {} - for kk in self.part_key: - batch[kk] = next(self.iters[kk][idx]) + + # Get the complex data first, which maintains its original randomness + # The random selection comes from the sampler configured in the DataLoader + first_key_data = next(self.iters[first_key][idx]) + batch[first_key] = first_key_data + + # Now, ensure other subsystems use the same frame indices + if 'fid' in first_key_data: + # Get the frame indices from the complex data + frame_indices = first_key_data['fid'] + + # For each other subsystem, get the same frames + for kk in self.part_key[1:]: + sub_batch = [] + # Each batch may contain multiple frames + # We need to match each frame individually + for i, frame_idx in enumerate(frame_indices): + # Use the dataset's __getitem__ method to get the exact frame + # This ensures we always get the matching frame + frame = self.systems[kk][idx][frame_idx] + sub_batch.append(frame) + + # Collate the selected frames into a batch + batch[kk] = collate_batch(sub_batch) + else: + # Fallback for cases where fid is not available + # Use the original method which may have synchronization issues + log.warning("Frame IDs not available, synchronization between subsystems may be imperfect") + for kk in self.part_key[1:]: + batch[kk] = next(self.iters[kk][idx]) except StopIteration: + # If complex iterator is exhausted, reset all iterators for kk in self.part_key: self.iters[kk][idx] = iter(self.dataloaders[kk][idx]) + + # Try again with reset iterators + first_key = self.part_key[0] batch = {} - for kk in self.part_key: - batch[kk] = next(self.iters[kk][idx]) + first_key_data = next(self.iters[first_key][idx]) + batch[first_key] = first_key_data + + # Same logic as above for getting matching frames + if 'fid' in first_key_data: + frame_indices = first_key_data['fid'] + + for kk in self.part_key[1:]: + sub_batch = [] + for i, frame_idx in enumerate(frame_indices): + frame = self.systems[kk][idx][frame_idx] + sub_batch.append(frame) + + batch[kk] = collate_batch(sub_batch) + else: + log.warning("Frame IDs not available, synchronization between subsystems may be imperfect") + for kk in self.part_key[1:]: + batch[kk] = next(self.iters[kk][idx]) + batch["sid"] = idx return batch