Skip to content

Commit ca28506

Browse files
authored
revise dataloader for multisystems in 1 batch when ener_bind (#49)
1 parent ff588a8 commit ca28506

1 file changed

Lines changed: 56 additions & 4 deletions

File tree

deepmd/pt/utils/dataloader.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,15 +241,67 @@ def __getitem__(self, idx):
241241
batch = next(self.iters[idx])
242242
else:
243243
try:
244+
# For binding datasets, we need to ensure all subsystems load the same frame
245+
# We'll get data from the first dataset (e.g. complex) first
246+
first_key = self.part_key[0]
244247
batch = {}
245-
for kk in self.part_key:
246-
batch[kk] = next(self.iters[kk][idx])
248+
249+
# Get the complex data first, which maintains its original randomness
250+
# The random selection comes from the sampler configured in the DataLoader
251+
first_key_data = next(self.iters[first_key][idx])
252+
batch[first_key] = first_key_data
253+
254+
# Now, ensure other subsystems use the same frame indices
255+
if 'fid' in first_key_data:
256+
# Get the frame indices from the complex data
257+
frame_indices = first_key_data['fid']
258+
259+
# For each other subsystem, get the same frames
260+
for kk in self.part_key[1:]:
261+
sub_batch = []
262+
# Each batch may contain multiple frames
263+
# We need to match each frame individually
264+
for i, frame_idx in enumerate(frame_indices):
265+
# Use the dataset's __getitem__ method to get the exact frame
266+
# This ensures we always get the matching frame
267+
frame = self.systems[kk][idx][frame_idx]
268+
sub_batch.append(frame)
269+
270+
# Collate the selected frames into a batch
271+
batch[kk] = collate_batch(sub_batch)
272+
else:
273+
# Fallback for cases where fid is not available
274+
# Use the original method which may have synchronization issues
275+
log.warning("Frame IDs not available, synchronization between subsystems may be imperfect")
276+
for kk in self.part_key[1:]:
277+
batch[kk] = next(self.iters[kk][idx])
247278
except StopIteration:
279+
# If complex iterator is exhausted, reset all iterators
248280
for kk in self.part_key:
249281
self.iters[kk][idx] = iter(self.dataloaders[kk][idx])
282+
283+
# Try again with reset iterators
284+
first_key = self.part_key[0]
250285
batch = {}
251-
for kk in self.part_key:
252-
batch[kk] = next(self.iters[kk][idx])
286+
first_key_data = next(self.iters[first_key][idx])
287+
batch[first_key] = first_key_data
288+
289+
# Same logic as above for getting matching frames
290+
if 'fid' in first_key_data:
291+
frame_indices = first_key_data['fid']
292+
293+
for kk in self.part_key[1:]:
294+
sub_batch = []
295+
for i, frame_idx in enumerate(frame_indices):
296+
frame = self.systems[kk][idx][frame_idx]
297+
sub_batch.append(frame)
298+
299+
batch[kk] = collate_batch(sub_batch)
300+
else:
301+
log.warning("Frame IDs not available, synchronization between subsystems may be imperfect")
302+
for kk in self.part_key[1:]:
303+
batch[kk] = next(self.iters[kk][idx])
304+
253305
batch["sid"] = idx
254306
return batch
255307

0 commit comments

Comments
 (0)