2727from deepmd .dpmodel .utils .nlist import (
2828 build_neighbor_list ,
2929 extend_coord_with_ghosts ,
30+ nlist_distinguish_types ,
3031)
3132from deepmd .dpmodel .utils .region import (
3233 normalize_coord ,
@@ -115,6 +116,7 @@ def __init__(
115116 ) -> None :
116117 self .output_def = output_def
117118 self .model_path = model_file
119+ self .neighbor_list = neighbor_list
118120
119121 # Load the exported program with metadata
120122 extra_files = {"model_def_script.json" : "" }
@@ -310,36 +312,38 @@ def _get_natoms_and_nframes(
310312 nframes = coords .shape [0 ]
311313 return natoms , nframes
312314
313- def _eval_model (
315+ def _build_nlist_native (
314316 self ,
315317 coords : np .ndarray ,
316318 cells : np .ndarray | None ,
317319 atom_types : np .ndarray ,
318- fparam : np .ndarray | None ,
319- aparam : np .ndarray | None ,
320- request_defs : list [OutputVariableDef ],
321- ) -> tuple [np .ndarray , ...]:
322- nframes = coords .shape [0 ]
323- if len (atom_types .shape ) == 1 :
324- natoms = len (atom_types )
325- atom_types = np .tile (atom_types , nframes ).reshape (nframes , - 1 )
326- else :
327- natoms = len (atom_types [0 ])
320+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
321+ """Build extended coords, atype, nlist, mapping using native nlist.
328322
323+ Parameters
324+ ----------
325+ coords : np.ndarray
326+ Coordinates, shape (nframes, natoms, 3).
327+ cells : np.ndarray or None
328+ Cell vectors, shape (nframes, 9). None for non-PBC.
329+ atom_types : np.ndarray
330+ Atom types, shape (nframes, natoms).
331+
332+ Returns
333+ -------
334+ extended_coord, extended_atype, nlist, mapping
335+ """
336+ nframes = coords .shape [0 ]
337+ natoms = coords .shape [1 ]
329338 rcut = self .rcut
330339 sel = self .metadata ["sel" ]
331340 mixed_types = self .metadata ["mixed_types" ]
332341
333- coord_input = coords .reshape (nframes , natoms , 3 )
334342 if cells is not None :
335343 box_input = cells .reshape (nframes , 3 , 3 )
344+ coord_normalized = normalize_coord (coords , box_input )
336345 else :
337- box_input = None
338-
339- if box_input is not None :
340- coord_normalized = normalize_coord (coord_input , box_input )
341- else :
342- coord_normalized = coord_input
346+ coord_normalized = coords
343347
344348 extended_coord , extended_atype , mapping = extend_coord_with_ghosts (
345349 coord_normalized ,
@@ -356,6 +360,212 @@ def _eval_model(
356360 distinguish_types = not mixed_types ,
357361 )
358362 extended_coord = extended_coord .reshape (nframes , - 1 , 3 )
363+ return extended_coord , extended_atype , nlist , mapping
364+
365+ def _build_nlist_ase (
366+ self ,
367+ coords : np .ndarray ,
368+ cells : np .ndarray | None ,
369+ atom_types : np .ndarray ,
370+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
371+ """Build extended coords, atype, nlist, mapping using ASE neighbor list.
372+
373+ Handles multiple frames by building per frame and padding to
374+ a common nall.
375+
376+ Parameters
377+ ----------
378+ coords : np.ndarray
379+ Coordinates, shape (nframes, natoms, 3).
380+ cells : np.ndarray or None
381+ Cell vectors, shape (nframes, 9). None for non-PBC.
382+ atom_types : np.ndarray
383+ Atom types, shape (nframes, natoms).
384+
385+ Returns
386+ -------
387+ extended_coord, extended_atype, nlist, mapping
388+ """
389+ nframes = coords .shape [0 ]
390+ frame_results = []
391+ for ff in range (nframes ):
392+ ec , ea , nl , mp = self ._build_nlist_ase_single (
393+ coords [ff ],
394+ cells [ff ] if cells is not None else None ,
395+ atom_types [ff ],
396+ )
397+ frame_results .append ((ec , ea , nl , mp ))
398+ # Pad to max nall across frames
399+ max_nall = max (ec .shape [0 ] for ec , _ , _ , _ in frame_results )
400+ ext_coords , ext_atypes , nlists , mappings = [], [], [], []
401+ for ec , ea , nl , mp in frame_results :
402+ pad = max_nall - ec .shape [0 ]
403+ if pad > 0 :
404+ ec = np .concatenate (
405+ [ec , np .zeros ((pad , 3 ), dtype = ec .dtype )],
406+ axis = 0 ,
407+ )
408+ ea = np .concatenate (
409+ [ea , np .full (pad , - 1 , dtype = ea .dtype )],
410+ axis = 0 ,
411+ )
412+ mp = np .concatenate (
413+ [mp , np .zeros (pad , dtype = mp .dtype )],
414+ axis = 0 ,
415+ )
416+ ext_coords .append (ec )
417+ ext_atypes .append (ea )
418+ nlists .append (nl )
419+ mappings .append (mp )
420+ return (
421+ np .stack (ext_coords , axis = 0 ),
422+ np .stack (ext_atypes , axis = 0 ),
423+ np .stack (nlists , axis = 0 ),
424+ np .stack (mappings , axis = 0 ),
425+ )
426+
427+ def _build_nlist_ase_single (
428+ self ,
429+ positions : np .ndarray ,
430+ cell : np .ndarray | None ,
431+ atype : np .ndarray ,
432+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray ]:
433+ """Build extended coords, atype, nlist, mapping for a single frame.
434+
435+ Parameters
436+ ----------
437+ positions : np.ndarray
438+ Atom positions, shape (natoms, 3).
439+ cell : np.ndarray or None
440+ Cell vector, shape (9,). None for non-PBC.
441+ atype : np.ndarray
442+ Atom types, shape (natoms,).
443+
444+ Returns
445+ -------
446+ extended_coord : np.ndarray, shape (nall, 3)
447+ extended_atype : np.ndarray, shape (nall,)
448+ nlist : np.ndarray, shape (nloc, nsel)
449+ mapping : np.ndarray, shape (nall,)
450+ """
451+ sel = self .metadata ["sel" ]
452+ mixed_types = self .metadata ["mixed_types" ]
453+ nsel = sum (sel )
454+
455+ natoms = positions .shape [0 ]
456+ cell_3x3 = (
457+ cell .reshape (3 , 3 )
458+ if cell is not None
459+ else np .zeros ((3 , 3 ), dtype = np .float64 )
460+ )
461+ pbc = np .repeat (cell is not None , 3 )
462+
463+ nl = self .neighbor_list
464+ nl .bothways = True
465+ nl .self_interaction = False
466+ if nl .update (pbc , cell_3x3 , positions ):
467+ nl .build (pbc , cell_3x3 , positions )
468+
469+ first_neigh = nl .first_neigh .copy ()
470+ pair_second = nl .pair_second .copy ()
471+ offset_vec = nl .offset_vec .copy ()
472+
473+ # Identify ghost atoms (out-of-box neighbors)
474+ out_mask = np .any (offset_vec != 0 , axis = 1 )
475+ out_idx = pair_second [out_mask ]
476+ out_offset = offset_vec [out_mask ]
477+ out_coords = positions [out_idx ] + out_offset .dot (cell_3x3 )
478+ out_atype = atype [out_idx ]
479+
480+ nloc = natoms
481+ nghost = out_idx .size
482+
483+ # Extended arrays (no leading frame dimension)
484+ extended_coord = np .concatenate ((positions , out_coords ), axis = 0 )
485+ extended_atype = np .concatenate ((atype , out_atype ))
486+ mapping = np .concatenate (
487+ (np .arange (nloc , dtype = np .int32 ), out_idx .astype (np .int32 ))
488+ )
489+
490+ # Remap neighbor indices: ghost atoms get new indices [nloc, nloc+nghost)
491+ ghost_remap = pair_second .copy ()
492+ ghost_remap [out_mask ] = np .arange (nloc , nloc + nghost , dtype = np .int64 )
493+
494+ # Build nlist: vectorized CSR-to-dense conversion
495+ rcut = self .rcut
496+ counts = np .diff (first_neigh )
497+ max_nn = int (counts .max ()) if counts .size > 0 else 0
498+
499+ # CSR to dense: (nloc, max_nn) neighbor index array, padded with -1
500+ col_idx = np .arange (len (ghost_remap ), dtype = np .int64 ) - np .repeat (
501+ first_neigh [:- 1 ], counts
502+ )
503+ row_idx = np .repeat (np .arange (nloc , dtype = np .int64 ), counts )
504+ dense_idx = np .full ((nloc , max_nn ), - 1 , dtype = np .int64 )
505+ dense_idx [row_idx , col_idx ] = ghost_remap
506+
507+ # Compute all distances at once
508+ valid = dense_idx >= 0
509+ lookup = np .where (valid , dense_idx , 0 )
510+ neigh_coords = extended_coord [lookup ] # (nloc, max_nn, 3)
511+ dists = np .linalg .norm (
512+ neigh_coords - positions [:, None , :], axis = - 1
513+ ) # (nloc, max_nn)
514+
515+ # Mask invalid and out-of-range, sort by distance
516+ valid &= dists <= rcut
517+ dists = np .where (valid , dists , np .inf )
518+ order = np .argsort (dists , axis = - 1 )
519+ sorted_idx = np .take_along_axis (dense_idx , order , axis = - 1 )
520+ sorted_valid = np .take_along_axis (valid , order , axis = - 1 )
521+
522+ # Take first nsel neighbors, pad if fewer than nsel
523+ if max_nn >= nsel :
524+ nlist = sorted_idx [:, :nsel ]
525+ nlist = np .where (sorted_valid [:, :nsel ], nlist , - 1 )
526+ else :
527+ nlist = np .full ((nloc , nsel ), - 1 , dtype = np .int64 )
528+ nlist [:, :max_nn ] = np .where (sorted_valid , sorted_idx , - 1 )
529+
530+ if not mixed_types :
531+ # nlist_distinguish_types expects (nframes, nloc, nsel)
532+ nlist = nlist_distinguish_types (
533+ nlist [None ],
534+ extended_atype [None ],
535+ sel ,
536+ )[0 ]
537+
538+ return extended_coord , extended_atype , nlist , mapping
539+
540+ def _eval_model (
541+ self ,
542+ coords : np .ndarray ,
543+ cells : np .ndarray | None ,
544+ atom_types : np .ndarray ,
545+ fparam : np .ndarray | None ,
546+ aparam : np .ndarray | None ,
547+ request_defs : list [OutputVariableDef ],
548+ ) -> tuple [np .ndarray , ...]:
549+ nframes = coords .shape [0 ]
550+ if len (atom_types .shape ) == 1 :
551+ natoms = len (atom_types )
552+ atom_types = np .tile (atom_types , nframes ).reshape (nframes , - 1 )
553+ else :
554+ natoms = len (atom_types [0 ])
555+
556+ coord_input = coords .reshape (nframes , natoms , 3 )
557+ if self .neighbor_list is not None :
558+ extended_coord , extended_atype , nlist , mapping = self ._build_nlist_ase (
559+ coord_input ,
560+ cells ,
561+ atom_types ,
562+ )
563+ else :
564+ extended_coord , extended_atype , nlist , mapping = self ._build_nlist_native (
565+ coord_input ,
566+ cells ,
567+ atom_types ,
568+ )
359569
360570 # Convert to torch tensors
361571 from deepmd .pt_expt .utils .env import (
0 commit comments