Skip to content

Commit 1a60bce

Browse files
committed
Documentation and docstring data
1 parent dd28d6c commit 1a60bce

File tree

2 files changed

+278
-54
lines changed

2 files changed

+278
-54
lines changed

pina/data/data_module.py

Lines changed: 147 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,22 @@ class DummyDataloader:
2323

2424
def __init__(self, dataset):
2525
"""
26-
param dataset: The dataset object to be processed.
27-
:notes:
28-
- **Distributed Environment**:
29-
- Divides the dataset across processes using the
30-
rank and world size.
31-
- Fetches only the portion of data corresponding to
32-
the current process.
33-
- **Non-Distributed Environment**:
34-
- Fetches the entire dataset.
26+
Prepare a dataloader object which will return the entire dataset
27+
in a single batch. Depending on the number of GPUs, we
28+
have the following cases:
29+
30+
- **Distributed Environment** (multiple GPUs):
31+
- Divides the dataset across processes using the rank and world size.
32+
- Fetches only the portion of data corresponding to the current process.
33+
- **Non-Distributed Environment** (single GPU):
34+
- Fetches the entire dataset.
35+
36+
:param dataset: The dataset object to be processed.
37+
:type dataset: PinaDataset
38+
39+
.. note:: This data loader is used when the batch size is None.
3540
"""
41+
3642
if (
3743
torch.distributed.is_available()
3844
and torch.distributed.is_initialized()
@@ -67,23 +73,50 @@ class Collator:
6773
Class used to collate the batch
6874
"""
6975

70-
def __init__(self, max_conditions_lengths, dataset=None):
76+
def __init__(
77+
self, max_conditions_lengths, automatic_batching, dataset=None
78+
):
79+
"""
80+
Initialize the object, setting the collate function based on whether
81+
automatic batching is enabled or not.
82+
83+
:param dict max_conditions_lengths: dict containing the maximum number of
84+
data points to consider in a single batch for each condition.
85+
:param PinaDataset dataset: The dataset where the data is stored.
86+
"""
87+
7188
self.max_conditions_lengths = max_conditions_lengths
89+
# Set the collate function based on the batching strategy
90+
# collate_pina_dataloader is used when automatic batching is disabled
91+
# collate_torch_dataloader is used when automatic batching is enabled
7292
self.callable_function = (
73-
self._collate_custom_dataloader
74-
if max_conditions_lengths is None
75-
else (self._collate_standard_dataloader)
93+
self._collate_torch_dataloader
94+
if automatic_batching
95+
else (self._collate_pina_dataloader)
7696
)
7797
self.dataset = dataset
98+
99+
# Set the function which performs the actual collation
78100
if isinstance(self.dataset, PinaTensorDataset):
101+
# If the dataset is a PinaTensorDataset, use this collate function
79102
self._collate = self._collate_tensor_dataset
80103
else:
104+
# If the dataset is a PinaDataset, use this collate function
81105
self._collate = self._collate_graph_dataset
82106

83-
def _collate_custom_dataloader(self, batch):
107+
def _collate_pina_dataloader(self, batch):
108+
"""
109+
Function used to create a batch when automatic batching is disabled.
110+
111+
:param list(int) batch: List of integers representing the indices of
112+
the data points to be fetched.
113+
:return: Dictionary containing the data points fetched from the dataset.
114+
:rtype: dict
115+
"""
116+
# Call the fetch_from_idx_list method of the dataset
84117
return self.dataset.fetch_from_idx_list(batch)
85118

86-
def _collate_standard_dataloader(self, batch):
119+
def _collate_torch_dataloader(self, batch):
87120
"""
88121
Function used to collate the batch
89122
"""
@@ -112,22 +145,56 @@ def _collate_standard_dataloader(self, batch):
112145

113146
@staticmethod
114147
def _collate_tensor_dataset(data_list):
148+
"""
149+
Function used to collate the data when the dataset is a
150+
`PinaTensorDataset`.
151+
152+
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
153+
collated.
154+
:type data_list: list(torch.Tensor) | list(LabelTensor)
155+
:raises RuntimeError: If the data is not a `torch.Tensor` or a
156+
`LabelTensor`.
157+
:return: Batch of data
158+
:rtype: dict
159+
"""
160+
115161
if isinstance(data_list[0], LabelTensor):
116162
return LabelTensor.stack(data_list)
117163
if isinstance(data_list[0], torch.Tensor):
118164
return torch.stack(data_list)
119165
raise RuntimeError("Data must be Tensors or LabelTensor ")
120166

121167
def _collate_graph_dataset(self, data_list):
168+
"""
169+
Function used to collate the data when the dataset is a
170+
`PinaGraphDataset`.
171+
172+
:param data_list: List of `Data` or `Graph` to be collated.
173+
:type data_list: list(Data) | list(Graph)
174+
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
175+
:return: Batch of data
176+
:rtype: dict
177+
"""
178+
122179
if isinstance(data_list[0], LabelTensor):
123180
return LabelTensor.cat(data_list)
124181
if isinstance(data_list[0], torch.Tensor):
125182
return torch.cat(data_list)
126183
if isinstance(data_list[0], Data):
127-
return self.dataset.create_graph_batch(data_list)
184+
return self.dataset.create_batch(data_list)
128185
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
129186

130187
def __call__(self, batch):
188+
"""
189+
Call the function to collate the batch, defined in __init__.
190+
191+
:param batch: list of indices or list of retrieved data
192+
:type batch: list(int) | list(dict)
193+
:return: Dictionary containing the data points fetched from the dataset,
194+
collated.
195+
:rtype: dict
196+
"""
197+
131198
return self.callable_function(batch)
132199

133200

@@ -137,6 +204,16 @@ class PinaSampler:
137204
"""
138205

139206
def __new__(cls, dataset, shuffle):
207+
"""
208+
Create the sampler instance, according to shuffle and whether the
209+
environment is distributed or not.
210+
211+
:param PinaDataset dataset: The dataset to be sampled.
212+
:param bool shuffle: whether to shuffle the dataset or not before
213+
sampling.
214+
:return: The sampler instance.
215+
:rtype: torch.utils.data.Sampler
216+
"""
140217

141218
if (
142219
torch.distributed.is_available()
@@ -173,29 +250,24 @@ def __init__(
173250
"""
174251
Initialize the object, creating datasets based on the input problem.
175252
176-
:param problem: The problem defining the dataset.
177-
:type problem: AbstractProblem
178-
:param train_size: Fraction or number of elements in the training split.
179-
:type train_size: float
180-
:param test_size: Fraction or number of elements in the test split.
181-
:type test_size: float
182-
:param val_size: Fraction or number of elements in the validation split.
183-
:type val_size: float
184-
:param batch_size: Batch size used for training. If None, the entire
185-
dataset is used per batch.
186-
:type batch_size: int or None
187-
:param shuffle: Whether to shuffle the dataset before splitting.
188-
:type shuffle: bool
189-
:param repeat: Whether to repeat the dataset indefinitely.
190-
:type repeat: bool
253+
:param AbstractProblem problem: The problem containing the data on which
254+
to train/test the model.
255+
:param float train_size: Fraction or number of elements in the training
256+
split.
257+
:param float test_size: Fraction or number of elements in the test
258+
split.
259+
:param float val_size: Fraction or number of elements in the validation
260+
split.
261+
:param batch_size: The batch size used for training. If `None`, the
262+
entire dataset is used per batch.
263+
:type batch_size: int | None
264+
:param bool shuffle: Whether to shuffle the dataset before splitting.
265+
:param bool repeat: Whether to repeat the dataset indefinitely.
191266
:param automatic_batching: Whether to enable automatic batching.
192-
:type automatic_batching: bool
193-
:param num_workers: Number of worker threads for data loading.
267+
:param int num_workers: Number of worker threads for data loading.
194268
Default 0 (serial loading)
195-
:type num_workers: int
196-
:param pin_memory: Whether to use pinned memory for faster data
269+
:param bool pin_memory: Whether to use pinned memory for faster data
197270
transfer to GPU. (Default False)
198-
:type pin_memory: bool
199271
"""
200272
super().__init__()
201273

@@ -365,10 +437,14 @@ def _create_dataloader(self, split, dataset):
365437
sampler = PinaSampler(dataset, shuffle)
366438
if self.automatic_batching:
367439
collate = Collator(
368-
self.find_max_conditions_lengths(split), dataset=dataset
440+
self.find_max_conditions_lengths(split),
441+
self.automatic_batching,
442+
dataset=dataset,
369443
)
370444
else:
371-
collate = Collator(None, dataset=dataset)
445+
collate = Collator(
446+
None, self.automatic_batching, dataset=dataset
447+
)
372448
return DataLoader(
373449
dataset,
374450
self.batch_size,
@@ -413,23 +489,51 @@ def val_dataloader(self):
413489
def train_dataloader(self):
414490
"""
415491
Create the training dataloader
492+
493+
:return: The training dataloader
494+
:rtype: DataLoader
416495
"""
417496
return self._create_dataloader("train", self.train_dataset)
418497

419498
def test_dataloader(self):
420499
"""
421500
Create the testing dataloader
501+
502+
:return: The testing dataloader
503+
:rtype: DataLoader
422504
"""
423505
return self._create_dataloader("test", self.test_dataset)
424506

425507
@staticmethod
426508
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
509+
"""
510+
Transfer the batch to the device. This method is called in the
511+
training loop and is used to transfer the batch to the device.
512+
This method is used when the batch size is None: batch has already
513+
been transferred to the device.
514+
515+
:param list(tuple) batch: list of tuple where the first element of the
516+
tuple is the condition name and the second element is the data.
517+
:param device: device to which the batch is transferred
518+
:type device: torch.device
519+
:param dataloader_idx: index of the dataloader
520+
:type dataloader_idx: int
521+
:return: The batch transferred to the device.
522+
:rtype: list(tuple)
523+
"""
427524
return batch
428525

429526
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
430527
"""
431528
Transfer the batch to the device. This method is called in the
432529
training loop and is used to transfer the batch to the device.
530+
531+
:param dict batch: The batch to be transferred to the device.
532+
:param device: The device to which the batch is transferred.
533+
:type device: torch.device
534+
:param int dataloader_idx: The index of the dataloader.
535+
:return: The batch transferred to the device.
536+
:rtype: list(tuple)
433537
"""
434538
batch = [
435539
(
@@ -456,13 +560,16 @@ def _check_slit_sizes(train_size, test_size, val_size):
456560
@property
457561
def input(self):
458562
"""
459-
# TODO
563+
Return all the input points coming from all the datasets.
564+
565+
:return: The input points for training.
566+
:rtype: dict
460567
"""
461568
to_return = {}
462569
if hasattr(self, "train_dataset") and self.train_dataset is not None:
463570
to_return["train"] = self.train_dataset.input
464571
if hasattr(self, "val_dataset") and self.val_dataset is not None:
465572
to_return["val"] = self.val_dataset.input
466573
if hasattr(self, "test_dataset") and self.test_dataset is not None:
467-
to_return = self.test_dataset.input
574+
to_return["test"] = self.test_dataset.input
468575
return to_return

0 commit comments

Comments
 (0)