@@ -67,8 +67,9 @@ def test_build_graph(x, pos):
6767 ),
6868 ],
6969)
70- def test_build_radius_graph (x , pos ):
71- graph = RadiusGraph (x = x , pos = pos , radius = 0.5 )
70+ @pytest .mark .parametrize ("loop" , [True , False ])
71+ def test_build_radius_graph (x , pos , loop ):
72+ graph = RadiusGraph (x = x , pos = pos , radius = 0.5 , loop = loop )
7273 assert hasattr (graph , "x" )
7374 assert hasattr (graph , "pos" )
7475 assert hasattr (graph , "edge_index" )
@@ -84,6 +85,15 @@ def test_build_radius_graph(x, pos):
8485 assert graph .pos .labels == pos .labels
8586 else :
8687 assert isinstance (graph .pos , torch .Tensor )
88+ if not loop :
89+ assert (
90+ len (
91+ torch .nonzero (
92+ graph .edge_index [0 ] == graph .edge_index [1 ], as_tuple = True
93+ )[0 ]
94+ )
95+ == 0
96+ ) # Detect self loops
8797
8898
8999@pytest .mark .parametrize (
@@ -168,8 +178,9 @@ def test_build_radius_graph_custom_edge_attr(x, pos):
168178 ),
169179 ],
170180)
171- def test_build_knn_graph (x , pos ):
172- graph = KNNGraph (x = x , pos = pos , neighbours = 2 )
181+ @pytest .mark .parametrize ("loop" , [True , False ])
182+ def test_build_knn_graph (x , pos , loop ):
183+ graph = KNNGraph (x = x , pos = pos , neighbours = 2 , loop = loop )
173184 assert hasattr (graph , "x" )
174185 assert hasattr (graph , "pos" )
175186 assert hasattr (graph , "edge_index" )
@@ -186,6 +197,15 @@ def test_build_knn_graph(x, pos):
186197 else :
187198 assert isinstance (graph .pos , torch .Tensor )
188199 assert graph .edge_attr is None
200+ self_loops = len (
201+ torch .nonzero (
202+ graph .edge_index [0 ] == graph .edge_index [1 ], as_tuple = True
203+ )[0 ]
204+ )
205+ if loop :
206+ assert self_loops != 0
207+ else :
208+ assert self_loops == 0
189209
190210
191211@pytest .mark .parametrize (
0 commit comments