@@ -23,16 +23,22 @@ class DummyDataloader:
2323
2424 def __init__ (self , dataset ):
2525 """
26- param dataset: The dataset object to be processed.
27- :notes:
28- - **Distributed Environment**:
29- - Divides the dataset across processes using the
30- rank and world size.
31- - Fetches only the portion of data corresponding to
32- the current process.
33- - **Non-Distributed Environment**:
34- - Fetches the entire dataset.
26+ Prepare a dataloader object which will return the entire dataset
27+ in a single batch. Depending on the number of GPUs, we
28+ have the following cases:
29+
30+ - **Distributed Environment** (multiple GPUs):
31+ - Divides the dataset across processes using the rank and world size.
32+ - Fetches only the portion of data corresponding to the current process.
33+ - **Non-Distributed Environment** (single GPU):
34+ - Fetches the entire dataset.
35+
36+ :param dataset: The dataset object to be processed.
37+ :type dataset: PinaDataset
38+
39+ .. note:: This data loader is used when the batch size is None.
3540 """
41+
3642 if (
3743 torch .distributed .is_available ()
3844 and torch .distributed .is_initialized ()
@@ -67,23 +73,50 @@ class Collator:
6773 Class used to collate the batch
6874 """
6975
70- def __init__ (self , max_conditions_lengths , dataset = None ):
76+ def __init__ (
77+ self , max_conditions_lengths , automatic_batching , dataset = None
78+ ):
79+ """
80+ Initialize the object, setting the collate function based on whether
81+ automatic batching is enabled or not.
82+
83+ :param dict max_conditions_lengths: dict containing the maximum
84+ number of data points to consider in a single batch per condition.
85+ :param bool automatic_batching: whether automatic batching is enabled.
86+ :param PinaDataset dataset: The dataset where the data is stored.
86+ """
87+
7188 self .max_conditions_lengths = max_conditions_lengths
89+ # Set the collate function based on the batching strategy
90+ # collate_pina_dataloader is used when automatic batching is disabled
91+ # collate_torch_dataloader is used when automatic batching is enabled
7292 self .callable_function = (
73- self ._collate_custom_dataloader
74- if max_conditions_lengths is None
75- else (self ._collate_standard_dataloader )
93+ self ._collate_torch_dataloader
94+ if automatic_batching
95+ else (self ._collate_pina_dataloader )
7696 )
7797 self .dataset = dataset
98+
99+ # Set the function which performs the actual collation
78100 if isinstance (self .dataset , PinaTensorDataset ):
101+ # If the dataset is a PinaTensorDataset, use this collate function
79102 self ._collate = self ._collate_tensor_dataset
80103 else :
104+ # Otherwise (non-tensor dataset), use the graph collate function
81105 self ._collate = self ._collate_graph_dataset
82106
83- def _collate_custom_dataloader (self , batch ):
107+ def _collate_pina_dataloader (self , batch ):
108+ """
109+ Function used to create a batch when automatic batching is disabled.
110+
111+ :param list(int) batch: List of integers representing the indices of
112+ the data points to be fetched.
113+ :return: Dictionary containing the data points fetched from the dataset.
114+ :rtype: dict
115+ """
116+ # Call the fetch_from_idx_list method of the dataset
84117 return self .dataset .fetch_from_idx_list (batch )
85118
86- def _collate_standard_dataloader (self , batch ):
119+ def _collate_torch_dataloader (self , batch ):
87120 """
88121 Function used to collate the batch
89122 """
@@ -112,22 +145,56 @@ def _collate_standard_dataloader(self, batch):
112145
113146 @staticmethod
114147 def _collate_tensor_dataset (data_list ):
148+ """
149+ Function used to collate the data when the dataset is a
150+ `PinaTensorDataset`.
151+
152+ :param data_list: List of `torch.Tensor` or `LabelTensor` to be
153+ collated.
154+ :type data_list: list(torch.Tensor) | list(LabelTensor)
155+ :raises RuntimeError: If the data is not a `torch.Tensor` or a
156+ `LabelTensor`.
157+ :return: Batch of data
158+ :rtype: dict
159+ """
160+
115161 if isinstance (data_list [0 ], LabelTensor ):
116162 return LabelTensor .stack (data_list )
117163 if isinstance (data_list [0 ], torch .Tensor ):
118164 return torch .stack (data_list )
119165 raise RuntimeError ("Data must be Tensors or LabelTensor " )
120166
121167 def _collate_graph_dataset (self , data_list ):
168+ """
169+ Function used to collate the data when the dataset is a
170+ `PinaGraphDataset`.
171+
172+ :param data_list: List of `Data` or `Graph` to be collated.
173+ :type data_list: list(Data) | list(Graph)
174+ :raises RuntimeError: If the data is not a `Data` or a `Graph`.
175+ :return: Batch of data
176+ :rtype: dict
177+ """
178+
122179 if isinstance (data_list [0 ], LabelTensor ):
123180 return LabelTensor .cat (data_list )
124181 if isinstance (data_list [0 ], torch .Tensor ):
125182 return torch .cat (data_list )
126183 if isinstance (data_list [0 ], Data ):
127- return self .dataset .create_graph_batch (data_list )
184+ return self .dataset .create_batch (data_list )
128185 raise RuntimeError ("Data must be Tensors or LabelTensor or pyG Data" )
129186
130187 def __call__ (self , batch ):
188+ """
189+ Call the function to collate the batch, defined in __init__.
190+
191+ :param batch: list of indices or list of retrieved data
192+ :type batch: list(int) | list(dict)
193+ :return: Dictionary containing the data points fetched from the dataset,
194+ collated.
195+ :rtype: dict
196+ """
197+
131198 return self .callable_function (batch )
132199
133200
@@ -137,6 +204,16 @@ class PinaSampler:
137204 """
138205
139206 def __new__ (cls , dataset , shuffle ):
207+ """
208+ Create the sampler instance, according to shuffle and whether the
209+ environment is distributed or not.
210+
211+ :param PinaDataset dataset: The dataset to be sampled.
212+ :param bool shuffle: whether to shuffle the dataset or not before
213+ sampling.
214+ :return: The sampler instance.
215+ :rtype: torch.utils.data.Sampler
216+ """
140217
141218 if (
142219 torch .distributed .is_available ()
@@ -173,29 +250,24 @@ def __init__(
173250 """
174251 Initialize the object, creating datasets based on the input problem.
175252
176- :param problem: The problem defining the dataset.
177- :type problem: AbstractProblem
178- :param train_size: Fraction or number of elements in the training split.
179- :type train_size: float
180- :param test_size: Fraction or number of elements in the test split.
181- :type test_size: float
182- :param val_size: Fraction or number of elements in the validation split.
183- :type val_size: float
184- :param batch_size: Batch size used for training. If None, the entire
185- dataset is used per batch.
186- :type batch_size: int or None
187- :param shuffle: Whether to shuffle the dataset before splitting.
188- :type shuffle: bool
189- :param repeat: Whether to repeat the dataset indefinitely.
190- :type repeat: bool
253+ :param AbstractProblem problem: The problem containing the data on which
254+ to train/test the model.
255+ :param float train_size: Fraction or number of elements in the training
256+ split.
257+ :param float test_size: Fraction or number of elements in the test
258+ split.
259+ :param float val_size: Fraction or number of elements in the validation
260+ split.
261+ :param batch_size: The batch size used for training. If `None`, the
262+ entire dataset is used per batch.
263+ :type batch_size: int | None
264+ :param bool shuffle: Whether to shuffle the dataset before splitting.
265+ :param bool repeat: Whether to repeat the dataset indefinitely.
191266 :param automatic_batching: Whether to enable automatic batching.
192- :type automatic_batching: bool
193- :param num_workers: Number of worker threads for data loading.
267+ :param int num_workers: Number of worker threads for data loading.
194268 Default 0 (serial loading)
195- :type num_workers: int
196- :param pin_memory: Whether to use pinned memory for faster data
269+ :param bool pin_memory: Whether to use pinned memory for faster data
197270 transfer to GPU. (Default False)
198- :type pin_memory: bool
199271 """
200272 super ().__init__ ()
201273
@@ -365,10 +437,14 @@ def _create_dataloader(self, split, dataset):
365437 sampler = PinaSampler (dataset , shuffle )
366438 if self .automatic_batching :
367439 collate = Collator (
368- self .find_max_conditions_lengths (split ), dataset = dataset
440+ self .find_max_conditions_lengths (split ),
441+ self .automatic_batching ,
442+ dataset = dataset ,
369443 )
370444 else :
371- collate = Collator (None , dataset = dataset )
445+ collate = Collator (
446+ None , self .automatic_batching , dataset = dataset
447+ )
372448 return DataLoader (
373449 dataset ,
374450 self .batch_size ,
@@ -413,23 +489,51 @@ def val_dataloader(self):
413489 def train_dataloader (self ):
414490 """
415491 Create the training dataloader
492+
493+ :return: The training dataloader
494+ :rtype: DataLoader
416495 """
417496 return self ._create_dataloader ("train" , self .train_dataset )
418497
419498 def test_dataloader (self ):
420499 """
421500 Create the testing dataloader
501+
502+ :return: The testing dataloader
503+ :rtype: DataLoader
422504 """
423505 return self ._create_dataloader ("test" , self .test_dataset )
424506
425507 @staticmethod
426508 def _transfer_batch_to_device_dummy (batch , device , dataloader_idx ):
509+ """
510+ Transfer the batch to the device. This method is called in the
511+ training loop and is used to transfer the batch to the device.
512+ This method is used when the batch size is None: batch has already
513+ been transferred to the device.
514+
515+ :param list(tuple) batch: list of tuple where the first element of the
516+ tuple is the condition name and the second element is the data.
517+ :param device: device to which the batch is transferred
518+ :type device: torch.device
519+ :param dataloader_idx: index of the dataloader
520+ :type dataloader_idx: int
521+ :return: The batch transferred to the device.
522+ :rtype: list(tuple)
523+ """
427524 return batch
428525
429526 def _transfer_batch_to_device (self , batch , device , dataloader_idx ):
430527 """
431528 Transfer the batch to the device. This method is called in the
432529 training loop and is used to transfer the batch to the device.
530+
531+ :param dict batch: The batch to be transferred to the device.
532+ :param device: The device to which the batch is transferred.
533+ :type device: torch.device
534+ :param int dataloader_idx: The index of the dataloader.
535+ :return: The batch transferred to the device.
536+ :rtype: list(tuple)
433537 """
434538 batch = [
435539 (
@@ -456,13 +560,16 @@ def _check_slit_sizes(train_size, test_size, val_size):
456560 @property
457561 def input (self ):
458562 """
459- # TODO
563+ Return all the input points coming from all the datasets.
564+
565+ :return: The input points for training.
566+ :rtype: dict
460567 """
461568 to_return = {}
462569 if hasattr (self , "train_dataset" ) and self .train_dataset is not None :
463570 to_return ["train" ] = self .train_dataset .input
464571 if hasattr (self , "val_dataset" ) and self .val_dataset is not None :
465572 to_return ["val" ] = self .val_dataset .input
466573 if hasattr (self , "test_dataset" ) and self .test_dataset is not None :
467- to_return = self .test_dataset .input
574+ to_return [ "test" ] = self .test_dataset .input
468575 return to_return
0 commit comments