@@ -117,7 +117,9 @@ def build_neighbor_list(
117117 assert list (diff .shape ) == [batch_size , nloc , nall , 3 ]
118118 rr = xp .linalg .vector_norm (diff , axis = - 1 )
119119 # if central atom has two zero distances, sorting sometimes can not exclude itself
120- rr -= xp .eye (nloc , nall , dtype = diff .dtype )[xp .newaxis , :, :]
120+ rr -= xp .eye (nloc , nall , dtype = diff .dtype , device = array_api_compat .device (diff ))[
121+ xp .newaxis , :, :
122+ ]
121123 nlist = xp .argsort (rr , axis = - 1 )
122124 rr = xp .sort (rr , axis = - 1 )
123125 rr = rr [:, :, 1 :]
@@ -128,11 +130,26 @@ def build_neighbor_list(
128130 nlist = nlist [:, :, :nsel ]
129131 else :
130132 rr = xp .concatenate (
131- [rr , xp .ones ([batch_size , nloc , nsel - nnei ], dtype = rr .dtype ) + rcut ],
133+ [
134+ rr ,
135+ xp .ones (
136+ [batch_size , nloc , nsel - nnei ],
137+ dtype = rr .dtype ,
138+ device = array_api_compat .device (rr ),
139+ )
140+ + rcut ,
141+ ],
132142 axis = - 1 ,
133143 )
134144 nlist = xp .concatenate (
135- [nlist , xp .ones ([batch_size , nloc , nsel - nnei ], dtype = nlist .dtype )],
145+ [
146+ nlist ,
147+ xp .ones (
148+ [batch_size , nloc , nsel - nnei ],
149+ dtype = nlist .dtype ,
150+ device = array_api_compat .device (nlist ),
151+ ),
152+ ],
136153 axis = - 1 ,
137154 )
138155 assert list (nlist .shape ) == [batch_size , nloc , nsel ]
@@ -218,7 +235,11 @@ def build_multiple_neighbor_list(
218235 return {}
219236 nb , nloc , nsel = nlist .shape
220237 if nsel < nsels [- 1 ]:
221- pad = - 1 * xp .ones ((nb , nloc , nsels [- 1 ] - nsel ), dtype = nlist .dtype )
238+ pad = - 1 * xp .ones (
239+ (nb , nloc , nsels [- 1 ] - nsel ),
240+ dtype = nlist .dtype ,
241+ device = array_api_compat .device (nlist ),
242+ )
222243 nlist = xp .concat ([nlist , pad ], axis = - 1 )
223244 nsel = nsels [- 1 ]
224245 coord1 = xp .reshape (coord , (nb , - 1 , 3 ))
@@ -276,7 +297,12 @@ def extend_coord_with_ghosts(
276297 xp = array_api_compat .array_namespace (coord , atype )
277298 nf , nloc = atype .shape
278299 # int64 for index
279- aidx = xp .tile (xp .arange (nloc , dtype = xp .int64 )[xp .newaxis , :], (nf , 1 ))
300+ aidx = xp .tile (
301+ xp .arange (nloc , dtype = xp .int64 , device = array_api_compat .device (atype ))[
302+ xp .newaxis , :
303+ ],
304+ (nf , 1 ),
305+ )
280306 if cell is None :
281307 nall = nloc
282308 extend_coord = coord
@@ -288,17 +314,41 @@ def extend_coord_with_ghosts(
288314 to_face = to_face_distance (cell )
289315 nbuff = xp .astype (xp .ceil (rcut / to_face ), xp .int64 )
290316 nbuff = xp .max (nbuff , axis = 0 )
291- xi = xp .arange (- int (nbuff [0 ]), int (nbuff [0 ]) + 1 , 1 , dtype = xp .int64 )
292- yi = xp .arange (- int (nbuff [1 ]), int (nbuff [1 ]) + 1 , 1 , dtype = xp .int64 )
293- zi = xp .arange (- int (nbuff [2 ]), int (nbuff [2 ]) + 1 , 1 , dtype = xp .int64 )
294- xyz = xp .linalg .outer (xi , xp .asarray ([1 , 0 , 0 ]))[:, xp .newaxis , xp .newaxis , :]
317+ xi = xp .arange (
318+ - int (nbuff [0 ]),
319+ int (nbuff [0 ]) + 1 ,
320+ 1 ,
321+ dtype = xp .int64 ,
322+ device = array_api_compat .device (coord ),
323+ )
324+ yi = xp .arange (
325+ - int (nbuff [1 ]),
326+ int (nbuff [1 ]) + 1 ,
327+ 1 ,
328+ dtype = xp .int64 ,
329+ device = array_api_compat .device (coord ),
330+ )
331+ zi = xp .arange (
332+ - int (nbuff [2 ]),
333+ int (nbuff [2 ]) + 1 ,
334+ 1 ,
335+ dtype = xp .int64 ,
336+ device = array_api_compat .device (coord ),
337+ )
338+ xyz = xp .linalg .outer (
339+ xi , xp .asarray ([1 , 0 , 0 ], device = array_api_compat .device (xi ))
340+ )[:, xp .newaxis , xp .newaxis , :]
295341 xyz = (
296342 xyz
297- + xp .linalg .outer (yi , xp .asarray ([0 , 1 , 0 ]))[xp .newaxis , :, xp .newaxis , :]
343+ + xp .linalg .outer (
344+ yi , xp .asarray ([0 , 1 , 0 ], device = array_api_compat .device (yi ))
345+ )[xp .newaxis , :, xp .newaxis , :]
298346 )
299347 xyz = (
300348 xyz
301- + xp .linalg .outer (zi , xp .asarray ([0 , 0 , 1 ]))[xp .newaxis , xp .newaxis , :, :]
349+ + xp .linalg .outer (
350+ zi , xp .asarray ([0 , 0 , 1 ], device = array_api_compat .device (zi ))
351+ )[xp .newaxis , xp .newaxis , :, :]
302352 )
303353 xyz = xp .reshape (xyz , (- 1 , 3 ))
304354 xyz = xp .astype (xyz , coord .dtype )
0 commit comments