@@ -241,15 +241,67 @@ def __getitem__(self, idx):
241241 batch = next (self .iters [idx ])
242242 else :
243243 try :
244+ # For binding datasets, we need to ensure all subsystems load the same frame
245+ # We'll get data from the first dataset (e.g. complex) first
246+ first_key = self .part_key [0 ]
244247 batch = {}
245- for kk in self .part_key :
246- batch [kk ] = next (self .iters [kk ][idx ])
248+
249+ # Get the complex data first, which maintains its original randomness
250+ # The random selection comes from the sampler configured in the DataLoader
251+ first_key_data = next (self .iters [first_key ][idx ])
252+ batch [first_key ] = first_key_data
253+
254+ # Now, ensure other subsystems use the same frame indices
255+ if 'fid' in first_key_data :
256+ # Get the frame indices from the complex data
257+ frame_indices = first_key_data ['fid' ]
258+
259+ # For each other subsystem, get the same frames
260+ for kk in self .part_key [1 :]:
261+ sub_batch = []
262+ # Each batch may contain multiple frames
263+ # We need to match each frame individually
264+ for i , frame_idx in enumerate (frame_indices ):
265+ # Use the dataset's __getitem__ method to get the exact frame
266+ # This ensures we always get the matching frame
267+ frame = self .systems [kk ][idx ][frame_idx ]
268+ sub_batch .append (frame )
269+
270+ # Collate the selected frames into a batch
271+ batch [kk ] = collate_batch (sub_batch )
272+ else :
273+ # Fallback for cases where fid is not available
274+ # Use the original method which may have synchronization issues
275+ log .warning ("Frame IDs not available, synchronization between subsystems may be imperfect" )
276+ for kk in self .part_key [1 :]:
277+ batch [kk ] = next (self .iters [kk ][idx ])
247278 except StopIteration :
279+ # If complex iterator is exhausted, reset all iterators
248280 for kk in self .part_key :
249281 self .iters [kk ][idx ] = iter (self .dataloaders [kk ][idx ])
282+
283+ # Try again with reset iterators
284+ first_key = self .part_key [0 ]
250285 batch = {}
251- for kk in self .part_key :
252- batch [kk ] = next (self .iters [kk ][idx ])
286+ first_key_data = next (self .iters [first_key ][idx ])
287+ batch [first_key ] = first_key_data
288+
289+ # Same logic as above for getting matching frames
290+ if 'fid' in first_key_data :
291+ frame_indices = first_key_data ['fid' ]
292+
293+ for kk in self .part_key [1 :]:
294+ sub_batch = []
295+ for i , frame_idx in enumerate (frame_indices ):
296+ frame = self .systems [kk ][idx ][frame_idx ]
297+ sub_batch .append (frame )
298+
299+ batch [kk ] = collate_batch (sub_batch )
300+ else :
301+ log .warning ("Frame IDs not available, synchronization between subsystems may be imperfect" )
302+ for kk in self .part_key [1 :]:
303+ batch [kk ] = next (self .iters [kk ][idx ])
304+
253305 batch ["sid" ] = idx
254306 return batch
255307
0 commit comments