Skip to content

Commit ebf90e3

Browse files
author
Han Wang
committed
support ase nlist
1 parent a31fd18 commit ebf90e3

2 files changed

Lines changed: 402 additions & 20 deletions

File tree

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 228 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from deepmd.dpmodel.utils.nlist import (
2828
build_neighbor_list,
2929
extend_coord_with_ghosts,
30+
nlist_distinguish_types,
3031
)
3132
from deepmd.dpmodel.utils.region import (
3233
normalize_coord,
@@ -115,6 +116,7 @@ def __init__(
115116
) -> None:
116117
self.output_def = output_def
117118
self.model_path = model_file
119+
self.neighbor_list = neighbor_list
118120

119121
# Load the exported program with metadata
120122
extra_files = {"model_def_script.json": ""}
@@ -310,36 +312,38 @@ def _get_natoms_and_nframes(
310312
nframes = coords.shape[0]
311313
return natoms, nframes
312314

313-
def _eval_model(
315+
def _build_nlist_native(
314316
self,
315317
coords: np.ndarray,
316318
cells: np.ndarray | None,
317319
atom_types: np.ndarray,
318-
fparam: np.ndarray | None,
319-
aparam: np.ndarray | None,
320-
request_defs: list[OutputVariableDef],
321-
) -> tuple[np.ndarray, ...]:
322-
nframes = coords.shape[0]
323-
if len(atom_types.shape) == 1:
324-
natoms = len(atom_types)
325-
atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
326-
else:
327-
natoms = len(atom_types[0])
320+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
321+
"""Build extended coords, atype, nlist, mapping using native nlist.
328322
323+
Parameters
324+
----------
325+
coords : np.ndarray
326+
Coordinates, shape (nframes, natoms, 3).
327+
cells : np.ndarray or None
328+
Cell vectors, shape (nframes, 9). None for non-PBC.
329+
atom_types : np.ndarray
330+
Atom types, shape (nframes, natoms).
331+
332+
Returns
333+
-------
334+
extended_coord, extended_atype, nlist, mapping
335+
"""
336+
nframes = coords.shape[0]
337+
natoms = coords.shape[1]
329338
rcut = self.rcut
330339
sel = self.metadata["sel"]
331340
mixed_types = self.metadata["mixed_types"]
332341

333-
coord_input = coords.reshape(nframes, natoms, 3)
334342
if cells is not None:
335343
box_input = cells.reshape(nframes, 3, 3)
344+
coord_normalized = normalize_coord(coords, box_input)
336345
else:
337-
box_input = None
338-
339-
if box_input is not None:
340-
coord_normalized = normalize_coord(coord_input, box_input)
341-
else:
342-
coord_normalized = coord_input
346+
coord_normalized = coords
343347

344348
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
345349
coord_normalized,
@@ -356,6 +360,212 @@ def _eval_model(
356360
distinguish_types=not mixed_types,
357361
)
358362
extended_coord = extended_coord.reshape(nframes, -1, 3)
363+
return extended_coord, extended_atype, nlist, mapping
364+
365+
def _build_nlist_ase(
366+
self,
367+
coords: np.ndarray,
368+
cells: np.ndarray | None,
369+
atom_types: np.ndarray,
370+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
371+
"""Build extended coords, atype, nlist, mapping using ASE neighbor list.
372+
373+
Handles multiple frames by building per frame and padding to
374+
a common nall.
375+
376+
Parameters
377+
----------
378+
coords : np.ndarray
379+
Coordinates, shape (nframes, natoms, 3).
380+
cells : np.ndarray or None
381+
Cell vectors, shape (nframes, 9). None for non-PBC.
382+
atom_types : np.ndarray
383+
Atom types, shape (nframes, natoms).
384+
385+
Returns
386+
-------
387+
extended_coord, extended_atype, nlist, mapping
388+
"""
389+
nframes = coords.shape[0]
390+
frame_results = []
391+
for ff in range(nframes):
392+
ec, ea, nl, mp = self._build_nlist_ase_single(
393+
coords[ff],
394+
cells[ff] if cells is not None else None,
395+
atom_types[ff],
396+
)
397+
frame_results.append((ec, ea, nl, mp))
398+
# Pad to max nall across frames
399+
max_nall = max(ec.shape[0] for ec, _, _, _ in frame_results)
400+
ext_coords, ext_atypes, nlists, mappings = [], [], [], []
401+
for ec, ea, nl, mp in frame_results:
402+
pad = max_nall - ec.shape[0]
403+
if pad > 0:
404+
ec = np.concatenate(
405+
[ec, np.zeros((pad, 3), dtype=ec.dtype)],
406+
axis=0,
407+
)
408+
ea = np.concatenate(
409+
[ea, np.full(pad, -1, dtype=ea.dtype)],
410+
axis=0,
411+
)
412+
mp = np.concatenate(
413+
[mp, np.zeros(pad, dtype=mp.dtype)],
414+
axis=0,
415+
)
416+
ext_coords.append(ec)
417+
ext_atypes.append(ea)
418+
nlists.append(nl)
419+
mappings.append(mp)
420+
return (
421+
np.stack(ext_coords, axis=0),
422+
np.stack(ext_atypes, axis=0),
423+
np.stack(nlists, axis=0),
424+
np.stack(mappings, axis=0),
425+
)
426+
427+
def _build_nlist_ase_single(
428+
self,
429+
positions: np.ndarray,
430+
cell: np.ndarray | None,
431+
atype: np.ndarray,
432+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
433+
"""Build extended coords, atype, nlist, mapping for a single frame.
434+
435+
Parameters
436+
----------
437+
positions : np.ndarray
438+
Atom positions, shape (natoms, 3).
439+
cell : np.ndarray or None
440+
Cell vector, shape (9,). None for non-PBC.
441+
atype : np.ndarray
442+
Atom types, shape (natoms,).
443+
444+
Returns
445+
-------
446+
extended_coord : np.ndarray, shape (nall, 3)
447+
extended_atype : np.ndarray, shape (nall,)
448+
nlist : np.ndarray, shape (nloc, nsel)
449+
mapping : np.ndarray, shape (nall,)
450+
"""
451+
sel = self.metadata["sel"]
452+
mixed_types = self.metadata["mixed_types"]
453+
nsel = sum(sel)
454+
455+
natoms = positions.shape[0]
456+
cell_3x3 = (
457+
cell.reshape(3, 3)
458+
if cell is not None
459+
else np.zeros((3, 3), dtype=np.float64)
460+
)
461+
pbc = np.repeat(cell is not None, 3)
462+
463+
nl = self.neighbor_list
464+
nl.bothways = True
465+
nl.self_interaction = False
466+
if nl.update(pbc, cell_3x3, positions):
467+
nl.build(pbc, cell_3x3, positions)
468+
469+
first_neigh = nl.first_neigh.copy()
470+
pair_second = nl.pair_second.copy()
471+
offset_vec = nl.offset_vec.copy()
472+
473+
# Identify ghost atoms (out-of-box neighbors)
474+
out_mask = np.any(offset_vec != 0, axis=1)
475+
out_idx = pair_second[out_mask]
476+
out_offset = offset_vec[out_mask]
477+
out_coords = positions[out_idx] + out_offset.dot(cell_3x3)
478+
out_atype = atype[out_idx]
479+
480+
nloc = natoms
481+
nghost = out_idx.size
482+
483+
# Extended arrays (no leading frame dimension)
484+
extended_coord = np.concatenate((positions, out_coords), axis=0)
485+
extended_atype = np.concatenate((atype, out_atype))
486+
mapping = np.concatenate(
487+
(np.arange(nloc, dtype=np.int32), out_idx.astype(np.int32))
488+
)
489+
490+
# Remap neighbor indices: ghost atoms get new indices [nloc, nloc+nghost)
491+
ghost_remap = pair_second.copy()
492+
ghost_remap[out_mask] = np.arange(nloc, nloc + nghost, dtype=np.int64)
493+
494+
# Build nlist: vectorized CSR-to-dense conversion
495+
rcut = self.rcut
496+
counts = np.diff(first_neigh)
497+
max_nn = int(counts.max()) if counts.size > 0 else 0
498+
499+
# CSR to dense: (nloc, max_nn) neighbor index array, padded with -1
500+
col_idx = np.arange(len(ghost_remap), dtype=np.int64) - np.repeat(
501+
first_neigh[:-1], counts
502+
)
503+
row_idx = np.repeat(np.arange(nloc, dtype=np.int64), counts)
504+
dense_idx = np.full((nloc, max_nn), -1, dtype=np.int64)
505+
dense_idx[row_idx, col_idx] = ghost_remap
506+
507+
# Compute all distances at once
508+
valid = dense_idx >= 0
509+
lookup = np.where(valid, dense_idx, 0)
510+
neigh_coords = extended_coord[lookup] # (nloc, max_nn, 3)
511+
dists = np.linalg.norm(
512+
neigh_coords - positions[:, None, :], axis=-1
513+
) # (nloc, max_nn)
514+
515+
# Mask invalid and out-of-range, sort by distance
516+
valid &= dists <= rcut
517+
dists = np.where(valid, dists, np.inf)
518+
order = np.argsort(dists, axis=-1)
519+
sorted_idx = np.take_along_axis(dense_idx, order, axis=-1)
520+
sorted_valid = np.take_along_axis(valid, order, axis=-1)
521+
522+
# Take first nsel neighbors, pad if fewer than nsel
523+
if max_nn >= nsel:
524+
nlist = sorted_idx[:, :nsel]
525+
nlist = np.where(sorted_valid[:, :nsel], nlist, -1)
526+
else:
527+
nlist = np.full((nloc, nsel), -1, dtype=np.int64)
528+
nlist[:, :max_nn] = np.where(sorted_valid, sorted_idx, -1)
529+
530+
if not mixed_types:
531+
# nlist_distinguish_types expects (nframes, nloc, nsel)
532+
nlist = nlist_distinguish_types(
533+
nlist[None],
534+
extended_atype[None],
535+
sel,
536+
)[0]
537+
538+
return extended_coord, extended_atype, nlist, mapping
539+
540+
def _eval_model(
541+
self,
542+
coords: np.ndarray,
543+
cells: np.ndarray | None,
544+
atom_types: np.ndarray,
545+
fparam: np.ndarray | None,
546+
aparam: np.ndarray | None,
547+
request_defs: list[OutputVariableDef],
548+
) -> tuple[np.ndarray, ...]:
549+
nframes = coords.shape[0]
550+
if len(atom_types.shape) == 1:
551+
natoms = len(atom_types)
552+
atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
553+
else:
554+
natoms = len(atom_types[0])
555+
556+
coord_input = coords.reshape(nframes, natoms, 3)
557+
if self.neighbor_list is not None:
558+
extended_coord, extended_atype, nlist, mapping = self._build_nlist_ase(
559+
coord_input,
560+
cells,
561+
atom_types,
562+
)
563+
else:
564+
extended_coord, extended_atype, nlist, mapping = self._build_nlist_native(
565+
coord_input,
566+
cells,
567+
atom_types,
568+
)
359569

360570
# Convert to torch tensors
361571
from deepmd.pt_expt.utils.env import (

0 commit comments

Comments
 (0)