Skip to content

Commit dd28d6c

Browse files
committed
Documentation and docstring graph
1 parent e541af0 commit dd28d6c

File tree

1 file changed

+61
-28
lines changed

1 file changed

+61
-28
lines changed

pina/graph.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def __new__(
1919
**kwargs,
2020
):
2121
"""
22+
Instantiates a new instance of the Graph class, performing type
23+
consistency checks.
24+
2225
:param kwargs: Parameters to construct the Graph object.
2326
:return: A new instance of the Graph class.
2427
:rtype: Graph
@@ -42,7 +45,10 @@ def __init__(
4245
**kwargs,
4346
):
4447
"""
45-
Initialize the Graph object.
48+
Initialize the Graph object by setting the node features, edge index,
49+
edge attributes, and positions. The edge index is preprocessed to make
50+
the graph undirected if required. For more details, see the
51+
:meth: `torch_geometric.data.Data`
4652
4753
:param x: Optional tensor of node features (N, F) where F is the number
4854
of features per node.
@@ -69,6 +75,13 @@ def __init__(
6975
)
7076

7177
def _check_type_consistency(self, **kwargs):
78+
"""
79+
Check the consistency of the types of the input data.
80+
81+
:param kwargs: Attributes to be checked for consistency.
82+
:type kwargs: dict
83+
"""
84+
7285
# default types, specified in cls.__new__, by default they are Nont
7386
# if specified in **kwargs they get override
7487
x, pos, edge_index, edge_attr = None, None, None, None
@@ -92,8 +105,10 @@ def _check_type_consistency(self, **kwargs):
92105
def _check_pos_consistency(pos):
93106
"""
94107
Check if the position tensor is consistent.
108+
95109
:param torch.Tensor pos: The position tensor.
96110
"""
111+
97112
if pos is not None:
98113
check_consistency(pos, (torch.Tensor, LabelTensor))
99114
if pos.ndim != 2:
@@ -103,8 +118,10 @@ def _check_pos_consistency(pos):
103118
def _check_edge_index_consistency(edge_index):
104119
"""
105120
Check if the edge index is consistent.
121+
106122
:param torch.Tensor edge_index: The edge index tensor.
107123
"""
124+
108125
check_consistency(edge_index, (torch.Tensor, LabelTensor))
109126
if edge_index.ndim != 2:
110127
raise ValueError("edge_index must be a 2D tensor.")
@@ -114,11 +131,13 @@ def _check_edge_index_consistency(edge_index):
114131
@staticmethod
115132
def _check_edge_attr_consistency(edge_attr, edge_index):
116133
"""
117-
Check if the edge attr is consistent.
118-
:param torch.Tensor edge_attr: The edge attribute tensor.
134+
Check if the edge attribute tensor is consistent in type and shape
135+
with the edge index.
119136
137+
:param torch.Tensor edge_attr: The edge attribute tensor.
120138
:param torch.Tensor edge_index: The edge index tensor.
121139
"""
140+
122141
if edge_attr is not None:
123142
check_consistency(edge_attr, (torch.Tensor, LabelTensor))
124143
if edge_attr.ndim != 2:
@@ -134,10 +153,13 @@ def _check_edge_attr_consistency(edge_attr, edge_index):
134153
@staticmethod
135154
def _check_x_consistency(x, pos=None):
136155
"""
137-
Check if the input tensor x is consistent with the position tensor pos.
156+
Check if the input tensor x is consistent with the position tensor
157+
`pos`.
158+
138159
:param torch.Tensor x: The input tensor.
139160
:param torch.Tensor pos: The position tensor.
140161
"""
162+
141163
if x is not None:
142164
check_consistency(x, (torch.Tensor, LabelTensor))
143165
if x.ndim != 2:
@@ -152,22 +174,24 @@ def _check_x_consistency(x, pos=None):
152174
@staticmethod
153175
def _preprocess_edge_index(edge_index, undirected):
154176
"""
155-
Preprocess the edge index.
177+
Preprocess the edge index to make the graph undirected (if required).
178+
156179
:param torch.Tensor edge_index: The edge index.
157180
:param bool undirected: Whether the graph is undirected.
158181
:return: The preprocessed edge index.
159182
:rtype: torch.Tensor
160183
"""
184+
161185
if undirected:
162186
edge_index = to_undirected(edge_index)
163187
return edge_index
164188

165189
def extract(self, labels, attr="x"):
166190
"""
167-
Perform extraction of labels on node features (x)
191+
Perform extraction of labels from the attribute specified by `attr`.
168192
169193
:param labels: Labels to extract
170-
:type labels: list[str] | tuple[str] | str
194+
:type labels: list[str] | tuple[str] | str | dict
171195
:return: Batch object with extraction performed on x
172196
:rtype: PinaBatch
173197
"""
@@ -193,21 +217,23 @@ def __new__(
193217
**kwargs,
194218
):
195219
"""
196-
Creates a new instance of the Graph class.
220+
Compute the edge attributes and create a new instance of the Graph class.
197221
198222
:param pos: A tensor of shape (N, D) representing the positions of N
199223
points in D-dimensional space.
200-
:type pos: torch.Tensor | LabelTensor
224+
:type pos: torch.Tensor or LabelTensor
201225
:param edge_index: A tensor of shape (2, E) representing the indices of
202226
the graph's edges.
203227
:type edge_index: torch.Tensor
204-
:param x: Optional tensor of node features (N, F) where F is the number
205-
of features per node.
206-
:type x: torch.Tensor, LabelTensor
207-
:param bool edge_attr: Optional edge attributes (E, F) where F is the
208-
number of features per edge.
209-
:param callable custom_edge_func: A custom function to compute edge
210-
attributes.
228+
:param x: Optional tensor of node features of shape (N, F), where F is
229+
the number of features per node.
230+
:type x: torch.Tensor | LabelTensor, optional
231+
:param edge_attr: Optional tensor of edge attributes of shape (E, F),
232+
where F is the number of features per edge.
233+
:type edge_attr: torch.Tensor, optional
234+
:param custom_edge_func: A custom function to compute edge attributes.
235+
If provided, overrides `edge_attr`.
236+
:type custom_edge_func: callable, optional
211237
:param kwargs: Additional keyword arguments passed to the Graph class
212238
constructor.
213239
:return: A Graph instance constructed using the provided information.
@@ -249,18 +275,18 @@ class RadiusGraph(GraphBuilder):
249275

250276
def __new__(cls, pos, radius, **kwargs):
251277
"""
252-
Creates a new instance of the Graph class using a radius-based graph
253-
construction.
278+
Extends the `GraphBuilder` class to compute edge_index based on a
279+
radius. Each point is connected to all the points within the radius.
254280
255281
:param pos: A tensor of shape (N, D) representing the positions of N
256282
points in D-dimensional space.
257-
:type pos: torch.Tensor | LabelTensor
258-
:param float radius: The radius within which points are connected.
259-
:Keyword Arguments:
260-
The additional keyword arguments to be passed to GraphBuilder
261-
and Graph classes
262-
:return: Graph instance containg the information passed in input and
263-
the computed edge_index
283+
:type pos: torch.Tensor or LabelTensor
284+
:param radius: The radius within which points are connected.
285+
:type radius: float
286+
:param kwargs: Additional keyword arguments to be passed to the
287+
`GraphBuilder` and `Graph` constructors.
288+
:return: A `Graph` instance containing the input information and the
289+
computed edge_index.
264290
:rtype: Graph
265291
"""
266292
edge_index = cls.compute_radius_graph(pos, radius)
@@ -269,7 +295,8 @@ def __new__(cls, pos, radius, **kwargs):
269295
@staticmethod
270296
def compute_radius_graph(points, radius):
271297
"""
272-
Computes a radius-based graph for a given set of points.
298+
Computes edge_index for a given set of points base on the radius.
299+
Each point is connected to all the points within the radius.
273300
274301
:param points: A tensor of shape (N, D) representing the positions of
275302
N points in D-dimensional space.
@@ -295,7 +322,7 @@ class KNNGraph(GraphBuilder):
295322
def __new__(cls, pos, neighbours, **kwargs):
296323
"""
297324
Creates a new instance of the Graph class using k-nearest neighbors
298-
to compute edge_index.
325+
algorithm to define the edges.
299326
300327
:param pos: A tensor of shape (N, D) representing the positions of N
301328
points in D-dimensional space.
@@ -323,8 +350,9 @@ def compute_knn_graph(points, k):
323350
N points in D-dimensional space.
324351
:type points: torch.Tensor | LabelTensor
325352
:param int k: The number of nearest neighbors to find for each point.
326-
:rtype torch.Tensor: A tensor of shape (2, E), where E is the number of
353+
:return: A tensor of shape (2, E), where E is the number of
327354
edges, representing the edge indices of the KNN graph.
355+
:rtype: torch.Tensor
328356
"""
329357

330358
dist = torch.cdist(points, points, p=2)
@@ -343,6 +371,11 @@ class LabelBatch(Batch):
343371
def from_data_list(cls, data_list):
344372
"""
345373
Create a Batch object from a list of Data objects.
374+
375+
:param data_list: List of Data/Graph objects
376+
:type data_list: list[Data] | list[Graph]
377+
:return: A Batch object containing the data in the list
378+
:rtype: Batch
346379
"""
347380
# Store the labels of Data/Graph objects (all data have the same labels)
348381
# If the data do not contain labels, labels is an empty dictionary,

0 commit comments

Comments
 (0)