Skip to content

Commit ccaa551

Browse files
authored
feat(pt): Add DPA4/SeZM descriptor & model 🎉🎉🎉 (#5448)
## Summary This PR adds PyTorch support for DPA4, the DeePMD-kit implementation of SeZM (Smooth Equivariant Zone-bridging Model). It introduces the DPA4/SeZM model, descriptor, fitting network, training integration, export path, documentation, examples, and tests. ## Main Changes - Add the DPA4/SeZM PyTorch model stack: - `model.type: "dpa4"` / `"sezm"` - `descriptor.type: "dpa4"` / `"sezm"` - `fitting_net.type: "dpa4_ener"` / `"sezm_ener"` - Implement the SO(3)-equivariant descriptor with edge-local SO(2) convolutions, angular schedules, smooth radial envelopes, attention/focus streams, and environment-seeded initial features. - Add zone-bridging support for short-range analytical repulsion, including ZBL coupling and descriptor-side short-range clamping. - Add DPA4 training support for: - conservative energy/force training through `loss.type: "ener"` - experimental direct-force denoising through `loss.type: "dens"` - spin models in the PyTorch backend - shared-fitting multitask case FiLM conditioning - LoRA fine-tuning and merged checkpoint export - Add the DPA4 `.pt2` freeze/export path using AOTInductor for checkpoints that cannot be represented by the regular TorchScript freeze path. - Add CLI, argcheck, validation, data-system, and inference integration needed to route DPA4 configs and exported models correctly. - Add water examples for standard DPA4, ZBL bridging, spin, DeNS, multitask/shared-fitting, LoRA fine-tuning, and LAMMPS inference. - Add official model documentation at `doc/model/dpa4.md`. ## Tests This PR adds coverage for: - DPA4/SeZM model and descriptor construction - DPA4 aliases in model, descriptor, and fitting configuration - SO(3)/SO(2) equivariance behavior - conservative energy/force paths - `torch.compile` eager/compiled consistency - DPA4 `.pt2` export and DeepPot inference - spin model behavior - ZBL zone bridging - DeNS loss and direct-force mode - LoRA adapter injection, freezing, merging, and compile compatibility - optional Triton kernel dispatch and numerical consistency - supporting utility changes in neighbor-list, LMDB data, and distributed checks Relevant test files include: - `source/tests/pt/model/test_descriptor_sezm.py` - `source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py` - `source/tests/pt/model/test_descriptor_sezm_triton.py` - `source/tests/pt/model/test_sezm_model.py` - `source/tests/pt/model/test_sezm_spin_model.py` - `source/tests/pt/model/test_sezm_export.py` - `source/tests/pt/test_training.py` - `source/tests/pt/test_train_utils.py` - `source/tests/common/dpmodel/test_dist_check.py` - `source/tests/common/dpmodel/test_lmdb_data.py` ## Notes DPA4 is currently implemented for the PyTorch backend. Model compression is not supported, and DPA4 checkpoints use the `.pt2` export path instead of the regular TorchScript freeze path. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added SeZM model family and DeNS denoising loss for training; new optimized ".pt2" export path with embedded metadata. * **Improvements** * LoRA fine-tuning workflow (apply/merge/strip) for lightweight adapters. * On-demand minimum pairwise-distance computation during data reads. * Better JAX neighbor-list handling and optional GPU/Triton-accelerated descriptor kernels for faster inference/training. <!-- review_stack_entry_start --> [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/deepmodeling/deepmd-kit/pull/5448?utm_source=github_walkthrough&utm_medium=github&utm_campaign=change_stack) <!-- review_stack_entry_end --> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0ce2351 commit ccaa551

95 files changed

Lines changed: 34184 additions & 223 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/copilot-setup-steps.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
run: uv pip install --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu
5353

5454
- name: Build Python package
55-
run: uv pip install -e .[cpu,test]
55+
run: uv pip install -e .[cpu,test,torch]
5656

5757
- name: Install prek tools
5858
run: uv tool install prek

.github/workflows/test_cc.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ jobs:
4646
run: |
4747
source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu
4848
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
49-
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
49+
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
50+
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax,torch] mpi4py mpich
51+
env:
52+
DP_ENABLE_PYTORCH: 1
5053
- name: Convert models
5154
run: source/tests/infer/convert-models.sh
5255
# https://github.com/actions/runner-images/issues/9491

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu
3232
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3333
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
34-
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu
34+
source/install/uv_with_retry.sh pip install --system -e .[test,jax,torch] mpi4py --group pin_jax_cpu
3535
source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310
3636
env:
3737
# Please note that uv has some issues with finding

backend/find_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
136136
if pt_version != ""
137137
# https://github.com/pytorch/pytorch/commit/7e0c26d4d80d6602aed95cb680dfc09c9ce533bc
138138
else "torch>=2.1.0",
139+
"e3nn>=0.5.9",
139140
*mpi_requirement,
140141
*cibw_requirement,
141142
],

deepmd/dpmodel/utils/dist_check.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Minimum pairwise distance check for frame validity filtering."""
3+
4+
from __future__ import (
5+
annotations,
6+
)
7+
8+
import numpy as np
9+
10+
_MIN_PAIR_DIST_BLOCK_PAIRS = 262_144
11+
12+
13+
def compute_min_pair_dist_single(
14+
coord: np.ndarray,
15+
box: np.ndarray | None,
16+
atype: np.ndarray,
17+
stop_below: float | None = None,
18+
) -> float:
19+
"""Compute the minimum pairwise atomic distance for a single frame.
20+
21+
Parameters
22+
----------
23+
coord : np.ndarray
24+
Atomic coordinates, flattened with shape (natoms * 3,)
25+
or reshaped as (natoms, 3).
26+
box : np.ndarray or None
27+
Box vectors with shape (9,) for PBC, or None for non-PBC.
28+
atype : np.ndarray
29+
Atom types with shape (natoms,). Virtual atoms (type < 0)
30+
are excluded from the distance check.
31+
stop_below : float or None
32+
Optional early-stop threshold. If a block has any pair closer
33+
than this value, the block minimum is returned immediately.
34+
35+
Returns
36+
-------
37+
float
38+
Minimum pairwise distance. Returns inf if fewer than 2
39+
real atoms exist.
40+
"""
41+
coord = coord.reshape(-1, 3)
42+
43+
# === Step 1. Filter out virtual atoms ===
44+
real_mask = atype.ravel() >= 0
45+
real_coord = coord[real_mask]
46+
n_real = real_coord.shape[0]
47+
if n_real < 2:
48+
return float("inf")
49+
50+
# === Step 2. Prepare minimum image convention for PBC ===
51+
if box is not None:
52+
cell = box.reshape(3, 3)
53+
inv_cell = np.linalg.inv(cell)
54+
else:
55+
cell = None
56+
inv_cell = None
57+
58+
# === Step 3. Compute distances in bounded row blocks ===
59+
block_size = max(1, min(n_real, _MIN_PAIR_DIST_BLOCK_PAIRS // n_real))
60+
min_dist_sq = float("inf")
61+
stop_dist_sq = (
62+
float(stop_below) * float(stop_below)
63+
if stop_below is not None and stop_below > 0.0
64+
else None
65+
)
66+
for start in range(0, n_real, block_size):
67+
stop = min(start + block_size, n_real)
68+
diff = real_coord[np.newaxis, :, :] - real_coord[start:stop, np.newaxis, :]
69+
70+
if cell is not None and inv_cell is not None:
71+
frac_diff = diff @ inv_cell
72+
frac_diff -= np.round(frac_diff)
73+
diff = frac_diff @ cell
74+
75+
dist_sq = np.sum(diff * diff, axis=-1)
76+
rows = np.arange(stop - start, dtype=np.int64)
77+
dist_sq[rows, start + rows] = np.inf
78+
min_dist_sq = min(min_dist_sq, float(dist_sq.min()))
79+
if min_dist_sq == 0.0 or (
80+
stop_dist_sq is not None and min_dist_sq < stop_dist_sq
81+
):
82+
break
83+
84+
return float(np.sqrt(min_dist_sq))

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import msgpack
2222
import numpy as np
2323

24+
from deepmd.dpmodel.utils.dist_check import (
25+
compute_min_pair_dist_single,
26+
)
2427
from deepmd.env import (
2528
GLOBAL_ENER_FLOAT_PRECISION,
2629
GLOBAL_NP_FLOAT_PRECISION,
@@ -597,6 +600,29 @@ def __getitem__(self, index: int) -> dict[str, Any]:
597600
frame["natoms"] = fallback
598601
frame["real_natoms_vec"] = fallback
599602

603+
if "min_pair_dist" in self._data_requirements and "min_pair_dist" not in frame:
604+
box = frame.get("box")
605+
if box is not None and np.allclose(box, 0.0):
606+
box = None
607+
req = self._data_requirements["min_pair_dist"]
608+
min_pair_dist = float(
609+
req.get("default", 0.0)
610+
if isinstance(req, dict)
611+
else getattr(req, "default", 0.0)
612+
)
613+
frame["find_min_pair_dist"] = np.float32(1.0)
614+
frame["min_pair_dist"] = np.array(
615+
[
616+
compute_min_pair_dist_single(
617+
frame["coord"],
618+
box,
619+
frame["atype"],
620+
stop_below=min_pair_dist,
621+
)
622+
],
623+
dtype=self._resolve_dtype("min_pair_dist"),
624+
)
625+
600626
# Add find_* flags for all data keys present in the frame.
601627
# Core structural keys and metadata are excluded — only label-like
602628
# and auxiliary data keys get find_* flags.

deepmd/dpmodel/utils/nlist.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,9 +355,16 @@ def extend_coord_with_ghosts(
355355
shift_idx = xp.take(xyz, xp.argsort(xp.linalg.vector_norm(xyz, axis=1)), axis=0)
356356
ns, _ = shift_idx.shape
357357
nall = ns * nloc
358-
# shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell)
359-
shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1]))
360-
shift_vec = xp.permute_dims(shift_vec, (1, 0, 2))
358+
if array_api_compat.is_jax_namespace(xp):
359+
# Avoid JAX internal errors in tensordot.
360+
shift_vec = xp.sum(
361+
shift_idx[xp.newaxis, :, :, xp.newaxis] * cell[:, xp.newaxis, :, :],
362+
axis=2,
363+
)
364+
else:
365+
# shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell)
366+
shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1]))
367+
shift_vec = xp.permute_dims(shift_vec, (1, 0, 2))
361368
extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :]
362369
extend_atype = xp.tile(atype[:, :, xp.newaxis], (1, ns, 1))
363370
extend_aidx = xp.tile(aidx[:, :, xp.newaxis], (1, ns, 1))

0 commit comments

Comments
 (0)