1212from torch .utils .data .distributed import DistributedSampler
1313from ..label_tensor import LabelTensor
1414from .dataset import PinaDatasetFactory , PinaTensorDataset
15- from ..collector import Collector
15+
16+ # from ..collector import Collector
1617
1718
1819class DummyDataloader :
@@ -330,9 +331,10 @@ def __init__(
330331 self .pin_memory = pin_memory
331332
332333 # Collect data
333- collector = Collector (problem )
334- collector .store_fixed_data ()
335- collector .store_sample_domains ()
334+ # collector = Collector(problem)
335+ # collector.store_fixed_data()
336+ # collector.store_sample_domains()
337+ problem .aggregate_data ()
336338
337339 # Check if the splits are correct
338340 self ._check_slit_sizes (train_size , test_size , val_size )
@@ -361,7 +363,7 @@ def __init__(
361363 # raises NotImplementedError
362364 self .val_dataloader = super ().val_dataloader
363365
364- self .collector_splits = self ._create_splits (collector , splits_dict )
366+ self .data_splits = self ._create_splits (problem . data , splits_dict )
365367 self .transfer_batch_to_device = self ._transfer_batch_to_device
366368
367369 def setup (self , stage = None ):
@@ -376,23 +378,23 @@ def setup(self, stage=None):
376378 """
377379 if stage == "fit" or stage is None :
378380 self .train_dataset = PinaDatasetFactory (
379- self .collector_splits ["train" ],
381+ self .data_splits ["train" ],
380382 max_conditions_lengths = self .find_max_conditions_lengths (
381383 "train"
382384 ),
383385 automatic_batching = self .automatic_batching ,
384386 )
385- if "val" in self .collector_splits .keys ():
387+ if "val" in self .data_splits .keys ():
386388 self .val_dataset = PinaDatasetFactory (
387- self .collector_splits ["val" ],
389+ self .data_splits ["val" ],
388390 max_conditions_lengths = self .find_max_conditions_lengths (
389391 "val"
390392 ),
391393 automatic_batching = self .automatic_batching ,
392394 )
393395 elif stage == "test" :
394396 self .test_dataset = PinaDatasetFactory (
395- self .collector_splits ["test" ],
397+ self .data_splits ["test" ],
396398 max_conditions_lengths = self .find_max_conditions_lengths ("test" ),
397399 automatic_batching = self .automatic_batching ,
398400 )
@@ -473,7 +475,7 @@ def _apply_shuffle(condition_dict, len_data):
473475 for (
474476 condition_name ,
475477 condition_dict ,
476- ) in collector .data_collections . items ():
478+ ) in collector .items ():
477479 len_data = len (condition_dict ["input" ])
478480 if self .shuffle :
479481 _apply_shuffle (condition_dict , len_data )
@@ -540,7 +542,7 @@ def find_max_conditions_lengths(self, split):
540542 """
541543
542544 max_conditions_lengths = {}
543- for k , v in self .collector_splits [split ].items ():
545+ for k , v in self .data_splits [split ].items ():
544546 if self .batch_size is None :
545547 max_conditions_lengths [k ] = len (v ["input" ])
546548 elif self .repeat :
0 commit comments