Skip to content

Commit 7f46cc0

Browse files
committed
Documentation and docstring graph and data
1 parent e541af0 commit 7f46cc0

File tree

3 files changed

+342
-83
lines changed

3 files changed

+342
-83
lines changed

pina/data/data_module.py

Lines changed: 149 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,24 @@ class DummyDataloader:
2323

2424
def __init__(self, dataset):
2525
"""
26-
param dataset: The dataset object to be processed.
27-
:notes:
28-
- **Distributed Environment**:
29-
- Divides the dataset across processes using the
30-
rank and world size.
31-
- Fetches only the portion of data corresponding to
32-
the current process.
33-
- **Non-Distributed Environment**:
34-
- Fetches the entire dataset.
26+
Prepare a dataloader object which will return the entire dataset
27+
in a single batch. Depending on the number of GPUs, we
28+
have the following cases:
29+
30+
- **Distributed Environment** (multiple GPUs):
31+
- Divides the dataset across processes using the rank and world
32+
size.
33+
- Fetches only the portion of data corresponding to the current
34+
process.
35+
- **Non-Distributed Environment** (single GPU):
36+
- Fetches the entire dataset.
37+
38+
:param dataset: The dataset object to be processed.
39+
:type dataset: PinaDataset
40+
41+
.. note:: This data loader is used when the batch size is None.
3542
"""
43+
3644
if (
3745
torch.distributed.is_available()
3846
and torch.distributed.is_initialized()
@@ -67,23 +75,50 @@ class Collator:
6775
Class used to collate the batch
6876
"""
6977

70-
def __init__(self, max_conditions_lengths, dataset=None):
78+
def __init__(
79+
self, max_conditions_lengths, automatic_batching, dataset=None
80+
):
81+
"""
82+
Initialize the object, setting the collate function based on whether
83+
automatic batching is enabled or not.
84+
85+
:param dict max_conditions_lengths: dict containing the maximum number
86+
of data points to consider in a single batch for each condition.
87+
:param PinaDataset dataset: The dataset where the data is stored.
88+
"""
89+
7190
self.max_conditions_lengths = max_conditions_lengths
91+
# Set the collate function based on the batching strategy
92+
# collate_pina_dataloader is used when automatic batching is disabled
93+
# collate_torch_dataloader is used when automatic batching is enabled
7294
self.callable_function = (
73-
self._collate_custom_dataloader
74-
if max_conditions_lengths is None
75-
else (self._collate_standard_dataloader)
95+
self._collate_torch_dataloader
96+
if automatic_batching
97+
else (self._collate_pina_dataloader)
7698
)
7799
self.dataset = dataset
100+
101+
# Set the function which performs the actual collation
78102
if isinstance(self.dataset, PinaTensorDataset):
103+
# If the dataset is a PinaTensorDataset, use this collate function
79104
self._collate = self._collate_tensor_dataset
80105
else:
106+
# If the dataset is a PinaDataset, use this collate function
81107
self._collate = self._collate_graph_dataset
82108

83-
def _collate_custom_dataloader(self, batch):
109+
def _collate_pina_dataloader(self, batch):
110+
"""
111+
Function used to create a batch when automatic batching is disabled.
112+
113+
:param list(int) batch: List of integers representing the indices of
114+
the data points to be fetched.
115+
:return: Dictionary containing the data points fetched from the dataset.
116+
:rtype: dict
117+
"""
118+
# Call the fetch_from_idx_list method of the dataset
84119
return self.dataset.fetch_from_idx_list(batch)
85120

86-
def _collate_standard_dataloader(self, batch):
121+
def _collate_torch_dataloader(self, batch):
87122
"""
88123
Function used to collate the batch
89124
"""
@@ -112,22 +147,56 @@ def _collate_standard_dataloader(self, batch):
112147

113148
@staticmethod
114149
def _collate_tensor_dataset(data_list):
150+
"""
151+
Function used to collate the data when the dataset is a
152+
`PinaTensorDataset`.
153+
154+
:param data_list: List of `torch.Tensor` or `LabelTensor` to be
155+
collated.
156+
:type data_list: list(torch.Tensor) | list(LabelTensor)
157+
:raises RuntimeError: If the data is not a `torch.Tensor` or a
158+
`LabelTensor`.
159+
:return: Batch of data
160+
:rtype: dict
161+
"""
162+
115163
if isinstance(data_list[0], LabelTensor):
116164
return LabelTensor.stack(data_list)
117165
if isinstance(data_list[0], torch.Tensor):
118166
return torch.stack(data_list)
119167
raise RuntimeError("Data must be Tensors or LabelTensor ")
120168

121169
def _collate_graph_dataset(self, data_list):
170+
"""
171+
Function used to collate the data when the dataset is a
172+
`PinaGraphDataset`.
173+
174+
:param data_list: List of `Data` or `Graph` to be collated.
175+
:type data_list: list(Data) | list(Graph)
176+
:raises RuntimeError: If the data is not a `Data` or a `Graph`.
177+
:return: Batch of data
178+
:rtype: dict
179+
"""
180+
122181
if isinstance(data_list[0], LabelTensor):
123182
return LabelTensor.cat(data_list)
124183
if isinstance(data_list[0], torch.Tensor):
125184
return torch.cat(data_list)
126185
if isinstance(data_list[0], Data):
127-
return self.dataset.create_graph_batch(data_list)
186+
return self.dataset.create_batch(data_list)
128187
raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data")
129188

130189
def __call__(self, batch):
190+
"""
191+
Call the function to collate the batch, defined in __init__.
192+
193+
:param batch: list of indices or list of retrieved data
194+
:type batch: list(int) | list(dict)
195+
:return: Dictionary containing the data points fetched from the dataset,
196+
collated.
197+
:rtype: dict
198+
"""
199+
131200
return self.callable_function(batch)
132201

133202

@@ -137,6 +206,16 @@ class PinaSampler:
137206
"""
138207

139208
def __new__(cls, dataset, shuffle):
209+
"""
210+
Create the sampler instance, according to shuffle and whether the
211+
environment is distributed or not.
212+
213+
:param PinaDataset dataset: The dataset to be sampled.
214+
:param bool shuffle: Whether to shuffle the dataset or not before
215+
sampling.
216+
:return: The sampler instance.
217+
:rtype: torch.utils.data.Sampler
218+
"""
140219

141220
if (
142221
torch.distributed.is_available()
@@ -173,29 +252,24 @@ def __init__(
173252
"""
174253
Initialize the object, creating datasets based on the input problem.
175254
176-
:param problem: The problem defining the dataset.
177-
:type problem: AbstractProblem
178-
:param train_size: Fraction or number of elements in the training split.
179-
:type train_size: float
180-
:param test_size: Fraction or number of elements in the test split.
181-
:type test_size: float
182-
:param val_size: Fraction or number of elements in the validation split.
183-
:type val_size: float
184-
:param batch_size: Batch size used for training. If None, the entire
185-
dataset is used per batch.
186-
:type batch_size: int or None
187-
:param shuffle: Whether to shuffle the dataset before splitting.
188-
:type shuffle: bool
189-
:param repeat: Whether to repeat the dataset indefinitely.
190-
:type repeat: bool
255+
:param AbstractProblem problem: The problem containing the data on which
256+
to train/test the model.
257+
:param float train_size: Fraction or number of elements in the training
258+
split.
259+
:param float test_size: Fraction or number of elements in the test
260+
split.
261+
:param float val_size: Fraction or number of elements in the validation
262+
split.
263+
:param batch_size: The batch size used for training. If `None`, the
264+
entire dataset is used per batch.
265+
:type batch_size: int | None
266+
:param bool shuffle: Whether to shuffle the dataset before splitting.
267+
:param bool repeat: Whether to repeat the dataset indefinitely.
191268
:param automatic_batching: Whether to enable automatic batching.
192-
:type automatic_batching: bool
193-
:param num_workers: Number of worker threads for data loading.
269+
:param int num_workers: Number of worker threads for data loading.
194270
Default 0 (serial loading)
195-
:type num_workers: int
196-
:param pin_memory: Whether to use pinned memory for faster data
271+
:param bool pin_memory: Whether to use pinned memory for faster data
197272
transfer to GPU. (Default False)
198-
:type pin_memory: bool
199273
"""
200274
super().__init__()
201275

@@ -365,10 +439,14 @@ def _create_dataloader(self, split, dataset):
365439
sampler = PinaSampler(dataset, shuffle)
366440
if self.automatic_batching:
367441
collate = Collator(
368-
self.find_max_conditions_lengths(split), dataset=dataset
442+
self.find_max_conditions_lengths(split),
443+
self.automatic_batching,
444+
dataset=dataset,
369445
)
370446
else:
371-
collate = Collator(None, dataset=dataset)
447+
collate = Collator(
448+
None, self.automatic_batching, dataset=dataset
449+
)
372450
return DataLoader(
373451
dataset,
374452
self.batch_size,
@@ -413,23 +491,51 @@ def val_dataloader(self):
413491
def train_dataloader(self):
414492
"""
415493
Create the training dataloader
494+
495+
:return: The training dataloader
496+
:rtype: DataLoader
416497
"""
417498
return self._create_dataloader("train", self.train_dataset)
418499

419500
def test_dataloader(self):
420501
"""
421502
Create the testing dataloader
503+
504+
:return: The testing dataloader
505+
:rtype: DataLoader
422506
"""
423507
return self._create_dataloader("test", self.test_dataset)
424508

425509
@staticmethod
426510
def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
511+
"""
512+
Transfer the batch to the device. This method is called in the
513+
training loop and is used to transfer the batch to the device.
514+
This method is used when the batch size is None: batch has already
515+
been transferred to the device.
516+
517+
:param list(tuple) batch: list of tuple where the first element of the
518+
tuple is the condition name and the second element is the data.
519+
:param device: device to which the batch is transferred
520+
:type device: torch.device
521+
:param dataloader_idx: index of the dataloader
522+
:type dataloader_idx: int
523+
:return: The batch transferred to the device.
524+
:rtype: list(tuple)
525+
"""
427526
return batch
428527

429528
def _transfer_batch_to_device(self, batch, device, dataloader_idx):
430529
"""
431530
Transfer the batch to the device. This method is called in the
432531
training loop and is used to transfer the batch to the device.
532+
533+
:param dict batch: The batch to be transferred to the device.
534+
:param device: The device to which the batch is transferred.
535+
:type device: torch.device
536+
:param int dataloader_idx: The index of the dataloader.
537+
:return: The batch transferred to the device.
538+
:rtype: list(tuple)
433539
"""
434540
batch = [
435541
(
@@ -456,13 +562,16 @@ def _check_slit_sizes(train_size, test_size, val_size):
456562
@property
457563
def input(self):
458564
"""
459-
# TODO
565+
Return all the input points coming from all the datasets.
566+
567+
:return: The input points for training.
568+
:rtype: dict
460569
"""
461570
to_return = {}
462571
if hasattr(self, "train_dataset") and self.train_dataset is not None:
463572
to_return["train"] = self.train_dataset.input
464573
if hasattr(self, "val_dataset") and self.val_dataset is not None:
465574
to_return["val"] = self.val_dataset.input
466575
if hasattr(self, "test_dataset") and self.test_dataset is not None:
467-
to_return = self.test_dataset.input
576+
to_return["test"] = self.test_dataset.input
468577
return to_return

0 commit comments

Comments
 (0)