@@ -23,16 +23,24 @@ class DummyDataloader:
2323
2424 def __init__ (self , dataset ):
2525 """
26- param dataset: The dataset object to be processed.
27- :notes:
28- - **Distributed Environment**:
29- - Divides the dataset across processes using the
30- rank and world size.
31- - Fetches only the portion of data corresponding to
32- the current process.
33- - **Non-Distributed Environment**:
34- - Fetches the entire dataset.
26+ Prepare a dataloader object which will return the entire dataset
27+ in a single batch. Depending on the number of GPUs, we have the
28+ following cases:
29+
30+ - **Distributed Environment** (multiple GPUs):
31+ - Divides the dataset across processes using the rank and world
32+ size.
33+ - Fetches only the portion of data corresponding to the current
34+ process.
35+ - **Non-Distributed Environment** (single GPU):
36+ - Fetches the entire dataset.
37+
38+ :param dataset: The dataset object to be processed.
39+ :type dataset: PinaDataset
40+
41+ .. note:: This data loader is used when the batch size is None.
3542 """
43+
3644 if (
3745 torch .distributed .is_available ()
3846 and torch .distributed .is_initialized ()
@@ -67,23 +75,50 @@ class Collator:
6775 Class used to collate the batch
6876 """
6977
def __init__(self, max_conditions_lengths, automatic_batching, dataset=None):
    """
    Initialize the object, setting the collate function based on whether
    automatic batching is enabled or not.

    :param dict max_conditions_lengths: Dict containing the maximum number
        of data points to consider in a single batch for each condition.
    :param bool automatic_batching: Whether automatic PyTorch batching is
        enabled or not.
    :param PinaDataset dataset: The dataset where the data is stored.
    """
    self.max_conditions_lengths = max_conditions_lengths
    # Select the collate entry point based on the batching strategy:
    # - _collate_torch_dataloader when automatic batching is enabled
    # - _collate_pina_dataloader when automatic batching is disabled
    self.callable_function = (
        self._collate_torch_dataloader
        if automatic_batching
        else self._collate_pina_dataloader
    )
    self.dataset = dataset

    # Set the function which performs the actual collation, depending on
    # the type of the dataset.
    if isinstance(self.dataset, PinaTensorDataset):
        # Plain tensor dataset: stack tensors condition-wise
        self._collate = self._collate_tensor_dataset
    else:
        # Graph dataset: concatenate / batch graph objects
        self._collate = self._collate_graph_dataset
82108
def _collate_pina_dataloader(self, batch):
    """
    Create a batch when automatic batching is disabled.

    :param list(int) batch: List of integers representing the indices of
        the data points to be fetched.
    :return: Dictionary containing the data points fetched from the dataset.
    :rtype: dict
    """
    # Delegate the whole fetch to the dataset, which returns the batch
    # already collated.
    return self.dataset.fetch_from_idx_list(batch)
85120
86- def _collate_standard_dataloader (self , batch ):
121+ def _collate_torch_dataloader (self , batch ):
87122 """
88123 Function used to collate the batch
89124 """
@@ -112,22 +147,56 @@ def _collate_standard_dataloader(self, batch):
112147
@staticmethod
def _collate_tensor_dataset(data_list):
    """
    Collate the data when the dataset is a `PinaTensorDataset`.

    :param data_list: List of `torch.Tensor` or `LabelTensor` to be
        collated.
    :type data_list: list(torch.Tensor) | list(LabelTensor)
    :raises RuntimeError: If the data is not a `torch.Tensor` or a
        `LabelTensor`.
    :return: Batch of data
    :rtype: dict
    """
    first = data_list[0]
    # LabelTensor is checked before torch.Tensor: presumably a Tensor
    # subclass whose labels must be preserved when stacking — TODO confirm.
    if isinstance(first, LabelTensor):
        return LabelTensor.stack(data_list)
    if isinstance(first, torch.Tensor):
        return torch.stack(data_list)
    raise RuntimeError("Data must be Tensors or LabelTensor ")
120168
def _collate_graph_dataset(self, data_list):
    """
    Collate the data when the dataset is a `PinaGraphDataset`.

    :param data_list: List of `Data` or `Graph` to be collated.
    :type data_list: list(Data) | list(Graph)
    :raises RuntimeError: If the data is not a `Data` or a `Graph`.
    :return: Batch of data
    :rtype: dict
    """
    item = data_list[0]
    # LabelTensor is checked before torch.Tensor: presumably a Tensor
    # subclass whose labels must be preserved when concatenating.
    if isinstance(item, LabelTensor):
        return LabelTensor.cat(data_list)
    if isinstance(item, torch.Tensor):
        return torch.cat(data_list)
    if isinstance(item, Data):
        # Graph objects: delegate batch construction to the dataset.
        return self.dataset.create_batch(data_list)
    raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
129188
def __call__(self, batch):
    """
    Collate the batch using the function selected in ``__init__``.

    :param batch: List of indices or list of retrieved data.
    :type batch: list(int) | list(dict)
    :return: Dictionary containing the data points fetched from the
        dataset, collated.
    :rtype: dict
    """
    collate = self.callable_function
    return collate(batch)
132201
133202
@@ -137,6 +206,16 @@ class PinaSampler:
137206 """
138207
139208 def __new__ (cls , dataset , shuffle ):
209+ """
210+ Create the sampler instance, according to shuffle and whether the
211+ environment is distributed or not.
212+
213+ :param PinaDataset dataset: The dataset to be sampled.
214+ :param bool shuffle: whether to shuffle the dataset or not before
215+ sampling.
216+ :return: The sampler instance.
217+ :rtype: torch.utils.data.Sampler
218+ """
140219
141220 if (
142221 torch .distributed .is_available ()
@@ -173,29 +252,24 @@ def __init__(
173252 """
174253 Initialize the object, creating datasets based on the input problem.
175254
176- :param problem: The problem defining the dataset.
177- :type problem: AbstractProblem
178- :param train_size: Fraction or number of elements in the training split.
179- :type train_size: float
180- :param test_size: Fraction or number of elements in the test split.
181- :type test_size: float
182- :param val_size: Fraction or number of elements in the validation split.
183- :type val_size: float
184- :param batch_size: Batch size used for training. If None, the entire
185- dataset is used per batch.
186- :type batch_size: int or None
187- :param shuffle: Whether to shuffle the dataset before splitting.
188- :type shuffle: bool
189- :param repeat: Whether to repeat the dataset indefinitely.
190- :type repeat: bool
255+ :param AbstractProblem problem: The problem containing the data on which
256+ to train/test the model.
257+ :param float train_size: Fraction or number of elements in the training
258+ split.
259+ :param float test_size: Fraction or number of elements in the test
260+ split.
261+ :param float val_size: Fraction or number of elements in the validation
262+ split.
263+ :param batch_size: The batch size used for training. If `None`, the
264+ entire dataset is used per batch.
265+ :type batch_size: int | None
266+ :param bool shuffle: Whether to shuffle the dataset before splitting.
267+ :param bool repeat: Whether to repeat the dataset indefinitely.
191268 :param bool automatic_batching: Whether to enable automatic batching.
192- :type automatic_batching: bool
193- :param num_workers: Number of worker threads for data loading.
269+ :param int num_workers: Number of worker threads for data loading.
194270 Default 0 (serial loading)
195- :type num_workers: int
196- :param pin_memory: Whether to use pinned memory for faster data
271+ :param bool pin_memory: Whether to use pinned memory for faster data
197272 transfer to GPU. (Default False)
198- :type pin_memory: bool
199273 """
200274 super ().__init__ ()
201275
@@ -365,10 +439,14 @@ def _create_dataloader(self, split, dataset):
365439 sampler = PinaSampler (dataset , shuffle )
366440 if self .automatic_batching :
367441 collate = Collator (
368- self .find_max_conditions_lengths (split ), dataset = dataset
442+ self .find_max_conditions_lengths (split ),
443+ self .automatic_batching ,
444+ dataset = dataset ,
369445 )
370446 else :
371- collate = Collator (None , dataset = dataset )
447+ collate = Collator (
448+ None , self .automatic_batching , dataset = dataset
449+ )
372450 return DataLoader (
373451 dataset ,
374452 self .batch_size ,
@@ -413,23 +491,51 @@ def val_dataloader(self):
def train_dataloader(self):
    """
    Create the training dataloader.

    :return: The training dataloader
    :rtype: DataLoader
    """
    # Build a dataloader over the training split.
    loader = self._create_dataloader("train", self.train_dataset)
    return loader
418499
def test_dataloader(self):
    """
    Create the testing dataloader.

    :return: The testing dataloader
    :rtype: DataLoader
    """
    # Build a dataloader over the test split.
    loader = self._create_dataloader("test", self.test_dataset)
    return loader
424508
@staticmethod
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
    """
    No-op device transfer hook, called in the training loop. It is used
    when the batch size is None: in that case the batch has already been
    transferred to the device, so it is returned unchanged.

    :param list(tuple) batch: List of tuples where the first element of
        the tuple is the condition name and the second element is the
        data.
    :param device: Device to which the batch would be transferred.
    :type device: torch.device
    :param int dataloader_idx: Index of the dataloader.
    :return: The batch, unchanged.
    :rtype: list(tuple)
    """
    return batch
428527
429528 def _transfer_batch_to_device (self , batch , device , dataloader_idx ):
430529 """
431530 Transfer the batch to the device. This method is called in the
432531 training loop and is used to transfer the batch to the device.
532+
533+ :param dict batch: The batch to be transferred to the device.
534+ :param device: The device to which the batch is transferred.
535+ :type device: torch.device
536+ :param int dataloader_idx: The index of the dataloader.
537+ :return: The batch transferred to the device.
538+ :rtype: list(tuple)
433539 """
434540 batch = [
435541 (
@@ -456,13 +562,16 @@ def _check_slit_sizes(train_size, test_size, val_size):
@property
def input(self):
    """
    Return all the input points coming from all the datasets.

    :return: The input points for training.
    :rtype: dict
    """
    to_return = {}
    # Include each split only if its dataset attribute exists and is set.
    if hasattr(self, "train_dataset") and self.train_dataset is not None:
        to_return["train"] = self.train_dataset.input
    if hasattr(self, "val_dataset") and self.val_dataset is not None:
        to_return["val"] = self.val_dataset.input
    if hasattr(self, "test_dataset") and self.test_dataset is not None:
        to_return["test"] = self.test_dataset.input
    return to_return
0 commit comments