@@ -578,7 +578,8 @@ def call(
578578 # n_angle x 1
579579 a_sw = (a_sw [:, :, :, None ] * a_sw [:, :, None , :])[a_nlist_mask ]
580580 else :
581- edge_index = angle_index = xp .zeros ([1 , 3 ], dtype = nlist .dtype )
581+ edge_index = xp .zeros ([2 , 1 ], dtype = nlist .dtype )
582+ angle_index = xp .zeros ([3 , 1 ], dtype = nlist .dtype )
582583
583584 # get edge and angle embedding
584585 # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
@@ -622,7 +623,7 @@ def call(
622623 edge_ebd ,
623624 h2 ,
624625 sw ,
625- owner = edge_index [:, 0 ],
626+ owner = edge_index [0 ],
626627 num_owner = nframes * nloc ,
627628 nb = nframes ,
628629 nloc = nloc ,
@@ -1286,8 +1287,8 @@ def call(
12861287 a_nlist : np .ndarray , # nf x nloc x a_nnei
12871288 a_nlist_mask : np .ndarray , # nf x nloc x a_nnei
12881289 a_sw : np .ndarray , # switch func, nf x nloc x a_nnei
1289- edge_index : np .ndarray , # n_edge x 2
1290- angle_index : np .ndarray , # n_angle x 3
1290+ edge_index : np .ndarray , # 2 x n_edge
1291+ angle_index : np .ndarray , # 3 x n_angle
12911292 ):
12921293 """
12931294 Parameters
@@ -1312,12 +1313,12 @@ def call(
13121313 Masks of the neighbor list for angle. real nei 1 otherwise 0
13131314 a_sw : nf x nloc x a_nnei
13141315 Switch function for angle.
1315- edge_index : Optional for dynamic sel, n_edge x 2
1316+ edge_index : Optional for dynamic sel, 2 x n_edge
13161317 n2e_index : n_edge
13171318 Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
13181319 n_ext2e_index : n_edge
13191320 Broadcast indices from extended node(j) to edge(ij).
1320- angle_index : Optional for dynamic sel, n_angle x 3
1321+ angle_index : Optional for dynamic sel, 3 x n_angle
13211322 n2a_index : n_angle
13221323 Broadcast indices from extended node(j) to angle(ijk).
13231324 eij2a_index : n_angle
@@ -1362,11 +1363,11 @@ def call(
13621363 assert (n_edge , 3 ) == h2 .shape
13631364 del a_nlist # may be used in the future
13641365
1365- n2e_index , n_ext2e_index = edge_index [:, 0 ], edge_index [:, 1 ]
1366+ n2e_index , n_ext2e_index = edge_index [0 ], edge_index [1 ]
13661367 n2a_index , eij2a_index , eik2a_index = (
1367- angle_index [:, 0 ],
1368- angle_index [:, 1 ],
1369- angle_index [:, 2 ],
1368+ angle_index [0 ],
1369+ angle_index [1 ],
1370+ angle_index [2 ],
13701371 )
13711372
13721373 # nb x nloc x nnei x n_dim [OR] n_edge x n_dim
0 commit comments