@@ -19,6 +19,9 @@ def __new__(
1919 ** kwargs ,
2020 ):
2121 """
22+ Instantiates a new instance of the Graph class, performing type
23+ consistency checks.
24+
2225 :param kwargs: Parameters to construct the Graph object.
2326 :return: A new instance of the Graph class.
2427 :rtype: Graph
@@ -42,7 +45,10 @@ def __init__(
4245 ** kwargs ,
4346 ):
4447 """
45- Initialize the Graph object.
48+ Initialize the Graph object by setting the node features, edge index,
49+ edge attributes, and positions. The edge index is preprocessed to make
50+ the graph undirected if required. For more details, see the
51+ :meth: `torch_geometric.data.Data`
4652
4753 :param x: Optional tensor of node features (N, F) where F is the number
4854 of features per node.
@@ -69,6 +75,13 @@ def __init__(
6975 )
7076
7177 def _check_type_consistency (self , ** kwargs ):
78+ """
79+ Check the consistency of the types of the input data.
80+
81+ :param kwargs: Attributes to be checked for consistency.
82+ :type kwargs: dict
83+ """
84+
7285 # default types, specified in cls.__new__, by default they are Nont
7386 # if specified in **kwargs they get override
7487 x , pos , edge_index , edge_attr = None , None , None , None
@@ -92,8 +105,10 @@ def _check_type_consistency(self, **kwargs):
92105 def _check_pos_consistency (pos ):
93106 """
94107 Check if the position tensor is consistent.
108+
95109 :param torch.Tensor pos: The position tensor.
96110 """
111+
97112 if pos is not None :
98113 check_consistency (pos , (torch .Tensor , LabelTensor ))
99114 if pos .ndim != 2 :
@@ -103,8 +118,10 @@ def _check_pos_consistency(pos):
103118 def _check_edge_index_consistency (edge_index ):
104119 """
105120 Check if the edge index is consistent.
121+
106122 :param torch.Tensor edge_index: The edge index tensor.
107123 """
124+
108125 check_consistency (edge_index , (torch .Tensor , LabelTensor ))
109126 if edge_index .ndim != 2 :
110127 raise ValueError ("edge_index must be a 2D tensor." )
@@ -114,11 +131,13 @@ def _check_edge_index_consistency(edge_index):
114131 @staticmethod
115132 def _check_edge_attr_consistency (edge_attr , edge_index ):
116133 """
117- Check if the edge attr is consistent.
118- :param torch.Tensor edge_attr: The edge attribute tensor .
134+ Check if the edge attribute tensor is consistent in type and shape
135+ with the edge index .
119136
137+ :param torch.Tensor edge_attr: The edge attribute tensor.
120138 :param torch.Tensor edge_index: The edge index tensor.
121139 """
140+
122141 if edge_attr is not None :
123142 check_consistency (edge_attr , (torch .Tensor , LabelTensor ))
124143 if edge_attr .ndim != 2 :
@@ -134,10 +153,13 @@ def _check_edge_attr_consistency(edge_attr, edge_index):
134153 @staticmethod
135154 def _check_x_consistency (x , pos = None ):
136155 """
137- Check if the input tensor x is consistent with the position tensor pos.
156+ Check if the input tensor x is consistent with the position tensor
157+ `pos`.
158+
138159 :param torch.Tensor x: The input tensor.
139160 :param torch.Tensor pos: The position tensor.
140161 """
162+
141163 if x is not None :
142164 check_consistency (x , (torch .Tensor , LabelTensor ))
143165 if x .ndim != 2 :
@@ -152,22 +174,24 @@ def _check_x_consistency(x, pos=None):
152174 @staticmethod
153175 def _preprocess_edge_index (edge_index , undirected ):
154176 """
155- Preprocess the edge index.
177+ Preprocess the edge index to make the graph undirected (if required).
178+
156179 :param torch.Tensor edge_index: The edge index.
157180 :param bool undirected: Whether the graph is undirected.
158181 :return: The preprocessed edge index.
159182 :rtype: torch.Tensor
160183 """
184+
161185 if undirected :
162186 edge_index = to_undirected (edge_index )
163187 return edge_index
164188
165189 def extract (self , labels , attr = "x" ):
166190 """
167- Perform extraction of labels on node features (x)
191+ Perform extraction of labels from the attribute specified by `attr`.
168192
169193 :param labels: Labels to extract
170- :type labels: list[str] | tuple[str] | str
194+ :type labels: list[str] | tuple[str] | str | dict
171195 :return: Batch object with extraction performed on x
172196 :rtype: PinaBatch
173197 """
@@ -193,21 +217,23 @@ def __new__(
193217 ** kwargs ,
194218 ):
195219 """
196- Creates a new instance of the Graph class.
220+ Compute the edge attributes and create a new instance of the Graph class.
197221
198222 :param pos: A tensor of shape (N, D) representing the positions of N
199223 points in D-dimensional space.
200- :type pos: torch.Tensor | LabelTensor
224+ :type pos: torch.Tensor or LabelTensor
201225 :param edge_index: A tensor of shape (2, E) representing the indices of
202226 the graph's edges.
203227 :type edge_index: torch.Tensor
204- :param x: Optional tensor of node features (N, F) where F is the number
205- of features per node.
206- :type x: torch.Tensor, LabelTensor
207- :param bool edge_attr: Optional edge attributes (E, F) where F is the
208- number of features per edge.
209- :param callable custom_edge_func: A custom function to compute edge
210- attributes.
228+ :param x: Optional tensor of node features of shape (N, F), where F is
229+ the number of features per node.
230+ :type x: torch.Tensor | LabelTensor, optional
231+ :param edge_attr: Optional tensor of edge attributes of shape (E, F),
232+ where F is the number of features per edge.
233+ :type edge_attr: torch.Tensor, optional
234+ :param custom_edge_func: A custom function to compute edge attributes.
235+ If provided, overrides `edge_attr`.
236+ :type custom_edge_func: callable, optional
211237 :param kwargs: Additional keyword arguments passed to the Graph class
212238 constructor.
213239 :return: A Graph instance constructed using the provided information.
@@ -249,18 +275,18 @@ class RadiusGraph(GraphBuilder):
249275
250276 def __new__ (cls , pos , radius , ** kwargs ):
251277 """
252- Creates a new instance of the Graph class using a radius- based graph
253- construction .
278+ Extends the `GraphBuilder` class to compute edge_index based on a
279+ radius. Each point is connected to all the points within the radius .
254280
255281 :param pos: A tensor of shape (N, D) representing the positions of N
256282 points in D-dimensional space.
257- :type pos: torch.Tensor | LabelTensor
258- :param float radius: The radius within which points are connected.
259- :Keyword Arguments:
260- The additional keyword arguments to be passed to GraphBuilder
261- and Graph classes
262- :return: Graph instance containg the information passed in input and
263- the computed edge_index
283+ :type pos: torch.Tensor or LabelTensor
284+ :param radius: The radius within which points are connected.
285+ :type radius: float
286+ :param kwargs: Additional keyword arguments to be passed to the
287+ `GraphBuilder` and ` Graph` constructors.
288+ :return: A ` Graph` instance containing the input information and the
289+ computed edge_index.
264290 :rtype: Graph
265291 """
266292 edge_index = cls .compute_radius_graph (pos , radius )
@@ -269,7 +295,8 @@ def __new__(cls, pos, radius, **kwargs):
269295 @staticmethod
270296 def compute_radius_graph (points , radius ):
271297 """
272- Computes a radius-based graph for a given set of points.
298+ Computes edge_index for a given set of points base on the radius.
299+ Each point is connected to all the points within the radius.
273300
274301 :param points: A tensor of shape (N, D) representing the positions of
275302 N points in D-dimensional space.
@@ -295,7 +322,7 @@ class KNNGraph(GraphBuilder):
295322 def __new__ (cls , pos , neighbours , ** kwargs ):
296323 """
297324 Creates a new instance of the Graph class using k-nearest neighbors
298- to compute edge_index .
325+ algorithm to define the edges .
299326
300327 :param pos: A tensor of shape (N, D) representing the positions of N
301328 points in D-dimensional space.
@@ -323,8 +350,9 @@ def compute_knn_graph(points, k):
323350 N points in D-dimensional space.
324351 :type points: torch.Tensor | LabelTensor
325352 :param int k: The number of nearest neighbors to find for each point.
326- :rtype torch.Tensor : A tensor of shape (2, E), where E is the number of
353+ :return : A tensor of shape (2, E), where E is the number of
327354 edges, representing the edge indices of the KNN graph.
355+ :rtype: torch.Tensor
328356 """
329357
330358 dist = torch .cdist (points , points , p = 2 )
@@ -343,6 +371,11 @@ class LabelBatch(Batch):
343371 def from_data_list (cls , data_list ):
344372 """
345373 Create a Batch object from a list of Data objects.
374+
375+ :param data_list: List of Data/Graph objects
376+ :type data_list: list[Data] | list[Graph]
377+ :return: A Batch object containing the data in the list
378+ :rtype: Batch
346379 """
347380 # Store the labels of Data/Graph objects (all data have the same labels)
348381 # If the data do not contain labels, labels is an empty dictionary,
0 commit comments