Skip to content

Commit 90ac1c3

Browse files
FilippoOlivodario-coscia
authored andcommitted
Update doc data
1 parent c91cdcc commit 90ac1c3

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

pina/data/data_module.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _collate_pina_dataloader(self, batch):
112112
"""
113113
Function used to create a batch when automatic batching is disabled.
114114
115-
:param list(int) batch: List of integers representing the indices of
115+
:param list[int] batch: List of integers representing the indices of
116116
the data points to be fetched.
117117
:return: Dictionary containing the data points fetched from the dataset.
118118
:rtype: dict
@@ -124,7 +124,7 @@ def _collate_torch_dataloader(self, batch):
124124
"""
125125
Function used to collate the batch
126126
127-
:param list(dict) batch: List of retrieved data.
127+
:param list[dict] batch: List of retrieved data.
128128
:return: Dictionary containing the data points fetched from the dataset,
129129
collated.
130130
:rtype: dict
@@ -160,7 +160,7 @@ def _collate_tensor_dataset(data_list):
160160
:class:`PinaTensorDataset`.
161161
162162
:param data_list: Elements to be collated.
163-
:type data_list: list(torch.Tensor) | list(LabelTensor)
163+
:type data_list: list[torch.Tensor] | list[LabelTensor]
164164
:return: Batch of data.
165165
:rtype: dict
166166
@@ -180,7 +180,7 @@ def _collate_graph_dataset(self, data_list):
180180
:class:`PinaGraphDataset`.
181181
182182
:param data_list: Elememts to be collated.
183-
:type data_list: list(torch_geometric.data.Data) | list(Graph)
183+
:type data_list: list[torch_geometric.data.Data] | list[Graph]
184184
:return: Batch of data.
185185
:rtype: dict
186186
@@ -206,7 +206,7 @@ def __call__(self, batch):
206206
during class initialization.
207207
208208
:param batch: List of retrieved data or sampled indices.
209-
:type batch: list(int) | list(dict)
209+
:type batch: list[int] | list[dict]
210210
:return: Dictionary containing the data points fetched from the dataset,
211211
collated.
212212
:rtype: dict
@@ -582,12 +582,12 @@ def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
582582
Transfer the batch to the device. This method is used when the batch
583583
size is None: batch has already been transferred to the device.
584584
585-
:param list(tuple) batch: list of tuple where the first element of the
585+
:param list[tuple] batch: List of tuple where the first element of the
586586
tuple is the condition name and the second element is the data.
587-
:param torch.device device: device to which the batch is transferred.
588-
:param int dataloader_idx: index of the dataloader.
587+
:param torch.device device: Device to which the batch is transferred.
588+
:param int dataloader_idx: Index of the dataloader.
589589
:return: The batch transferred to the device.
590-
:rtype: list(tuple)
590+
:rtype: list[tuple]
591591
"""
592592

593593
return batch
@@ -602,7 +602,7 @@ def _transfer_batch_to_device(self, batch, device, dataloader_idx):
602602
transferred.
603603
:param int dataloader_idx: The index of the dataloader.
604604
:return: The batch transferred to the device.
605-
:rtype: list(tuple)
605+
:rtype: list[tuple]
606606
"""
607607

608608
batch = [

pina/data/dataset.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def fetch_from_idx_list(self, idx):
175175
Return data from the dataset given a list of indices.
176176
177177
:param idx: List of indices.
178-
:type idx: list
178+
:type idx: list[int]
179179
:return: A dictionary containing the data at the given indices.
180180
:rtype: dict
181181
"""
@@ -216,7 +216,7 @@ def _retrive_data(self, data, idx_list):
216216
:param data: Dictionary containing the data
217217
(only torch.Tensor/LableTensor).
218218
:type data: dict
219-
:param list(int) idx_list: indices to retrieve.
219+
:param list[int] idx_list: indices to retrieve.
220220
:return: Dictionary containing the data at the given indices.
221221
:rtype: dict
222222
"""
@@ -246,7 +246,7 @@ def _create_graph_batch(self, data):
246246
:class:`torch_geometric.data.Data` objects.
247247
248248
:param data: List of items to collate in a single batch.
249-
:type data: list(torch_geometric.data.Data) | list(Graph)
249+
:type data: list[torch_geometric.data.Data] | list[Graph]
250250
:return: LabelBatch object all the graph collated in a single batch
251251
disconnected graphs.
252252
:rtype: LabelBatch
@@ -256,7 +256,8 @@ def _create_graph_batch(self, data):
256256

257257
def _create_tensor_batch(self, data):
258258
"""
259-
Create a torch.Tensor object from a list of torch.Tensor objects.
259+
Reshape properly ``data`` tensor to be processed handle by the graph
260+
based models.
260261
261262
:param data: torch.Tensor object of shape (N, ...) where N is the
262263
number of data points.
@@ -273,7 +274,7 @@ def create_batch(self, data):
273274
objects.
274275
275276
:param data: List of items to collate in a single batch.
276-
:type data: list
277+
:type data: list[torch_geometric.data.Data] | list[Graph]
277278
:return: Batch object.
278279
:rtype: Batch | PinaBatch
279280
"""
@@ -288,7 +289,7 @@ def _retrive_data(self, data, idx_list):
288289
Retrieve data from the dataset given a list of indices.
289290
290291
:param dict data: Dictionary containing the data.
291-
:param list idx_list: List of indices to retrieve.
292+
:param list[int] idx_list: List of indices to retrieve.
292293
:return: Dictionary containing the data at the given indices.
293294
:rtype: dict
294295
"""

0 commit comments

Comments
 (0)