Skip to content

Commit a55c6e3

Browse files
committed
Update network.py
1 parent 83ff09e commit a55c6e3

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,11 @@ def get_graph_index(
10761076
# edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate
10771077
edge_id = xp.arange(n_edge, dtype=nlist.dtype)
10781078
edge_index = xp.zeros((nf, nloc, nnei), dtype=nlist.dtype)
1079-
edge_index[xp.astype(nlist_mask, xp.bool)] = edge_id
1079+
if array_api_compat.is_jax_array(nlist):
1080+
# JAX doesn't support in-place item assignment
1081+
edge_index = edge_index.at[xp.astype(nlist_mask, xp.bool)].set(edge_id)
1082+
else:
1083+
edge_index[xp.astype(nlist_mask, xp.bool)] = edge_id
10801084
# only cut a_nnei neighbors, to avoid nnei x nnei
10811085
edge_index = edge_index[:, :, :a_nnei]
10821086
edge_index_ij = xp.broadcast_to(

0 commit comments

Comments
 (0)