We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 83ff09e commit a55c6e3Copy full SHA for a55c6e3
1 file changed
deepmd/dpmodel/utils/network.py
@@ -1076,7 +1076,11 @@ def get_graph_index(
1076
# edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate
1077
edge_id = xp.arange(n_edge, dtype=nlist.dtype)
1078
edge_index = xp.zeros((nf, nloc, nnei), dtype=nlist.dtype)
1079
- edge_index[xp.astype(nlist_mask, xp.bool)] = edge_id
+ if array_api_compat.is_jax_array(nlist):
1080
+ # JAX doesn't support in-place item assignment
1081
+ edge_index = edge_index.at[xp.astype(nlist_mask, xp.bool)].set(edge_id)
1082
+ else:
1083
+ edge_index[xp.astype(nlist_mask, xp.bool)] = edge_id
1084
# only cut a_nnei neighbors, to avoid nnei x nnei
1085
edge_index = edge_index[:, :, :a_nnei]
1086
edge_index_ij = xp.broadcast_to(
0 commit comments