From 805565e06b5ca571a6779655596d8fd060193105 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 23:20:39 +0800 Subject: [PATCH 1/8] feat: add skills for adding new descriptors --- skills/add-descriptor/SKILL.md | 206 +++++++++ .../references/dpmodel-implementation.md | 104 +++++ .../references/test-patterns.md | 400 ++++++++++++++++++ 3 files changed, 710 insertions(+) create mode 100644 skills/add-descriptor/SKILL.md create mode 100644 skills/add-descriptor/references/dpmodel-implementation.md create mode 100644 skills/add-descriptor/references/test-patterns.md diff --git a/skills/add-descriptor/SKILL.md b/skills/add-descriptor/SKILL.md new file mode 100644 index 0000000000..2d06c817b0 --- /dev/null +++ b/skills/add-descriptor/SKILL.md @@ -0,0 +1,206 @@ +--- +name: add-descriptor +description: Guides through adding a new descriptor type to deepmd-kit. Covers implementing in dpmodel (array-API-compatible), wrapping for JAX/pt_expt backends, hard-coding for PT/PD, registering arguments, and writing all required tests. +license: LGPL-3.0-or-later +compatibility: Requires Python 3.10+, numpy, pytest. Optional backends for full testing (torch, jax, paddle). +metadata: + author: deepmd-kit + version: "1.0" +--- + +# Adding a New Descriptor to deepmd-kit + +Follow these steps in order. Each step lists files to create/modify and patterns to follow. + +## Step 1: Implement in dpmodel + +**Create** `deepmd/dpmodel/descriptor/.py` + +Inherit from `NativeOP` and `BaseDescriptor`. Register with decorators: + +```python +from deepmd.dpmodel import NativeOP +from .base_descriptor import BaseDescriptor + + +@BaseDescriptor.register("your_name") +@BaseDescriptor.register("alias_name") # optional aliases +class DescrptYourName(NativeOP, BaseDescriptor): ... +``` + +Key requirements: + +- `__init__`: initialize cutoff, sel, networks, davg/dstd statistics +- `call(coord_ext, atype_ext, nlist, mapping=None)`: forward pass returning `(descriptor, rot_mat, g2, h2, sw)` +- `serialize() -> dict`: save with `@class`, `type`, `@version`, `@variables` keys +- `deserialize(cls, data)`: reconstruct from dict +- Property/getter methods: `get_rcut`, `get_sel`, `get_dim_out`, `mixed_types`, etc. +- `__getitem__`/`__setitem__` for `davg`/`dstd` access via multiple key aliases + +All dpmodel code **must** use `array_api_compat` for cross-backend compatibility (numpy/torch/jax/paddle). See [references/dpmodel-implementation.md](references/dpmodel-implementation.md) for full method table, array API pitfalls, and utilities. + +**Reference implementations**: + +- Simple: `deepmd/dpmodel/descriptor/se_e2_a.py` +- Three-body: `deepmd/dpmodel/descriptor/se_t.py` +- Attention-based: `deepmd/dpmodel/descriptor/dpa1.py` + +## Step 2: Register + +**Edit** `deepmd/dpmodel/descriptor/__init__.py` — add import and `__all__` entry. + +**Edit** `deepmd/utils/argcheck.py` — register descriptor arguments: + +```python +@descrpt_args_plugin.register("your_name", alias=["alias"], doc="Description") +def descrpt_your_name_args() -> list[Argument]: + return [ + Argument("sel", [list[int], str], optional=True, default="auto", doc=doc_sel), + Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), + Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth), + Argument( + "neuron", list[int], optional=True, default=[10, 20, 40], doc=doc_neuron + ), + # ... add all constructor parameters + ] +``` + +## Step 3: Wrap for JAX backend + +**Create** `deepmd/jax/descriptor/.py` + +Pattern: `@flax_module` decorator + custom `__setattr__` for attribute conversion. + +```python +from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP +from deepmd.jax.common import ArrayAPIVariable, flax_module, to_jax_array +from deepmd.jax.descriptor.base_descriptor import BaseDescriptor + + +@BaseDescriptor.register("your_name") +@flax_module +class DescrptYourName(DescrptYourNameDP): + def __setattr__(self, name, value): + if name in {"davg", "dstd"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + pass # stateless + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + return super().__setattr__(name, value) +``` + +For nested sub-components, define wrapper classes bottom-up. See `deepmd/jax/descriptor/dpa1.py` for example. + +**Edit** `deepmd/jax/descriptor/__init__.py` — add import and `__all__` entry. + +## Step 4: Wrap for pt_expt backend + +**Create** `deepmd/pt_expt/descriptor/.py` + +Pattern: `@torch_module` decorator + `forward()` method delegating to `call()`. + +```python +from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP +from deepmd.pt_expt.common import torch_module +from deepmd.pt_expt.descriptor.base_descriptor import BaseDescriptor + + +@BaseDescriptor.register("your_name") +@torch_module +class DescrptYourName(DescrptYourNameDP): + def forward(self, *args, **kwargs): + return self.call(*args, **kwargs) +``` + +For nested sub-components, wrap and register bottom-up with `register_dpmodel_mapping`. See `deepmd/pt_expt/descriptor/se_t_tebd.py` + `se_t_tebd_block.py`. + +**Edit** `deepmd/pt_expt/descriptor/__init__.py` — add import and `__all__` entry. + +## Step 5: Hard-code for PT backend (if needed) + +**Create** `deepmd/pt/model/descriptor/.py` + +PT descriptors are fully reimplemented in PyTorch (not wrapping dpmodel). They inherit from `BaseDescriptor` and `torch.nn.Module`. Must implement `forward()`, `serialize()`, `deserialize()`. + +**Edit** `deepmd/pt/model/descriptor/__init__.py` — add import. + +Reference: `deepmd/pt/model/descriptor/se_a.py` + +## Step 6: Hard-code for PD backend (if needed) + +Same as PT but using Paddle. Inherit from `BaseDescriptor` and `paddle.nn.Layer`. + +**Edit** `deepmd/pd/model/descriptor/__init__.py` — add import. + +Reference: `deepmd/pd/model/descriptor/se_a.py` + +## Step 7: Write tests + +Seven test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates. + +| Test | File | Purpose | +| --------------------- | -------------------------------------------------------------- | ----------------------------------- | +| 7a. dpmodel | `source/tests/common/dpmodel/test_descriptor_.py` | Serialize/deserialize round-trip | +| 7b. pt_expt | `source/tests/pt_expt/descriptor/test_.py` | Consistency + exportable + make_fx | +| 7c. PT | `source/tests/pt/model/test_descriptor_.py` | PT hard-coded tests (if applicable) | +| 7d. PD | `source/tests/pd/model/test_descriptor_.py` | PD hard-coded tests (if applicable) | +| 7e. array_api_strict | `source/tests/array_api_strict/descriptor/.py` | Wrapper for consistency tests | +| 7f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parameterized entry | +| 7g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parameterized entry | +| 7h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend comparison | + +## Verification + +```bash +# dpmodel self-consistency +python -m pytest source/tests/common/dpmodel/test_descriptor_.py -v + +# pt_expt unit tests +python -m pytest source/tests/pt_expt/descriptor/test_.py -v + +# Cross-backend consistency +python -m pytest source/tests/consistent/descriptor/test_.py -v + +# PT/PD unit tests (if hard-coded) +python -m pytest source/tests/pt/model/test_descriptor_.py -v +python -m pytest source/tests/pd/model/test_descriptor_.py -v + +# Quick smoke test +python -c " +from deepmd.dpmodel.descriptor import DescrptYourName +d = DescrptYourName(rcut=6.0, rcut_smth=1.8, sel=[20, 20]) +d2 = DescrptYourName.deserialize(d.serialize()) +print('Round-trip OK:', d.get_dim_out() == d2.get_dim_out()) +" +``` + +## Files summary + +| Step | Action | File | +| ---- | ------ | -------------------------------------------------------------- | +| 1 | Create | `deepmd/dpmodel/descriptor/.py` | +| 2 | Edit | `deepmd/dpmodel/descriptor/__init__.py` | +| 2 | Edit | `deepmd/utils/argcheck.py` | +| 3 | Create | `deepmd/jax/descriptor/.py` | +| 3 | Edit | `deepmd/jax/descriptor/__init__.py` | +| 4 | Create | `deepmd/pt_expt/descriptor/.py` | +| 4 | Edit | `deepmd/pt_expt/descriptor/__init__.py` | +| 5 | Create | `deepmd/pt/model/descriptor/.py` (if needed) | +| 5 | Edit | `deepmd/pt/model/descriptor/__init__.py` (if needed) | +| 6 | Create | `deepmd/pd/model/descriptor/.py` (if needed) | +| 6 | Edit | `deepmd/pd/model/descriptor/__init__.py` (if needed) | +| 7a | Create | `source/tests/common/dpmodel/test_descriptor_.py` | +| 7b | Create | `source/tests/pt_expt/descriptor/test_.py` | +| 7c | Create | `source/tests/pt/model/test_descriptor_.py` (if PT) | +| 7d | Create | `source/tests/pd/model/test_descriptor_.py` (if PD) | +| 7e | Create | `source/tests/array_api_strict/descriptor/.py` | +| 7e | Edit | `source/tests/array_api_strict/descriptor/__init__.py` | +| 7f | Edit | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | +| 7g | Edit | `source/tests/universal/pt/descriptor/test_descriptor.py` | +| 7h | Create | `source/tests/consistent/descriptor/test_.py` | diff --git a/skills/add-descriptor/references/dpmodel-implementation.md b/skills/add-descriptor/references/dpmodel-implementation.md new file mode 100644 index 0000000000..f4717b41a2 --- /dev/null +++ b/skills/add-descriptor/references/dpmodel-implementation.md @@ -0,0 +1,104 @@ +# dpmodel Implementation Details + +## Required methods + +| Method | Purpose | +| ------------------------------------------------------- | ------------------------------------------------------------ | +| `__init__(self, rcut, rcut_smth, sel, ...)` | Initialize cutoff, sel, networks, statistics | +| `call(self, coord_ext, atype_ext, nlist, mapping=None)` | Forward pass, returns `(descriptor, rot_mat, g2, h2, sw)` | +| `serialize(self) -> dict` | Save to dict with `@class`, `type`, `@version`, `@variables` | +| `deserialize(cls, data) -> Self` | Reconstruct from dict | +| `get_rcut() -> float` | Cutoff radius | +| `get_rcut_smth() -> float` | Smooth cutoff | +| `get_sel() -> list[int]` | Neighbor selection per type | +| `get_ntypes() -> int` | Number of atom types | +| `get_type_map() -> list[str]` | Type map | +| `get_dim_out() -> int` | Output descriptor dimension | +| `get_dim_emb() -> int` | Embedding dimension | +| `get_env_protection() -> float` | Environment protection value | +| `mixed_types() -> bool` | Whether descriptor mixes types | +| `has_message_passing() -> bool` | Whether it uses message passing | +| `need_sorted_nlist_for_lower() -> bool` | Whether nlist must be sorted | +| `compute_input_stats(merged, path)` | Compute davg/dstd from data | +| `set_stat_mean_and_stddev(mean, stddev)` | Set statistics | +| `get_stat_mean_and_stddev()` | Get statistics | +| `change_type_map(type_map, ...)` | Handle type map changes | +| `share_params(base_class, shared_level, resume)` | Parameter sharing | +| `update_sel(cls, train_data, type_map, local_jdata)` | Auto-update sel | + +## Statistics handling + +Support both naming conventions via `__getitem__`/`__setitem__`: + +```python +def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.davg = value + elif key in ("std", "data_std", "dstd"): + self.dstd = value + else: + raise KeyError(key) + + +def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.davg + elif key in ("std", "data_std", "dstd"): + return self.dstd + else: + raise KeyError(key) +``` + +## Key utilities + +| Utility | Import from | Purpose | +| ------------------- | ----------------------------------- | ------------------------------ | +| `EnvMat` | `deepmd.dpmodel.utils.env_mat` | Environment matrix computation | +| `EmbeddingNet` | `deepmd.dpmodel.utils.network` | Embedding neural network | +| `NetworkCollection` | `deepmd.dpmodel.utils.network` | Manages type-indexed networks | +| `PairExcludeMask` | `deepmd.dpmodel.utils.exclude_mask` | Type exclusion pairs | +| `EnvMatStatSe` | `deepmd.dpmodel.utils.env_mat_stat` | Statistics computation | + +## Array API compatibility (CRITICAL) + +All dpmodel code must use `array_api_compat` to work across numpy/torch/jax/paddle: + +```python +import array_api_compat + +xp = array_api_compat.array_namespace(coord_ext) +device = array_api_compat.device(coord_ext) +``` + +Rules: + +1. **Never use `np.einsum` on arrays that might be torch tensors** — torch disables `__array_function__` so `np.einsum` fails on tensors with `requires_grad=True`. Use `xp.sum` with broadcasting: + + ```python + # BAD: np.einsum("lni,lnj->lij", gg, tr) + # GOOD: xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) + ``` + +2. **`xp.zeros`/`xp.ones` must include `device=`** — omitting device can trigger CUDA init or create tensors on wrong device: + + ```python + # BAD: xp.zeros([2, 1], dtype=nlist.dtype) + # GOOD: xp.zeros([2, 1], dtype=nlist.dtype, device=array_api_compat.device(nlist)) + ``` + +3. **`xp.split` with `axis=` keyword doesn't work for torch** — use slicing: + + ```python + # BAD: g2, h2 = xp.split(dmatrix, [1], axis=-1) + # GOOD: g2, h2 = dmatrix[..., :1], dmatrix[..., 1:] + ``` + +4. **`xp_take_along_axis` indices must be int64 for torch**. + +5. **Don't maintain separate ArrayAPI subclasses** — dpmodel classes should be array_api compatible directly. + +6. **Boolean fancy indexing (`arr[mask]`) is not array-API compatible** — use mask multiplication: + ```python + # BAD: gr[ti_mask] += gr_tmp + # GOOD: gr += gr_tmp * xp.astype(mask[:, None, None], gr_tmp.dtype) + ``` diff --git a/skills/add-descriptor/references/test-patterns.md b/skills/add-descriptor/references/test-patterns.md new file mode 100644 index 0000000000..5de3543843 --- /dev/null +++ b/skills/add-descriptor/references/test-patterns.md @@ -0,0 +1,400 @@ +# Test Patterns for Descriptors + +## 7a. dpmodel self-consistency test + +**Create** `source/tests/common/dpmodel/test_descriptor_.py` + +```python +import unittest +import numpy as np +from deepmd.dpmodel.descriptor import DescrptYourName +from ...seed import GLOBAL_SEED +from .case_single_frame_with_nlist import TestCaseSingleFrameWithNlist + + +class TestDescrptYourName(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_self_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal( + size=(self.nt, nnei, 4) + ) # 4 for full env mat, 1 for radial-only + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + em0 = DescrptYourName(self.rcut, self.rcut_smth, self.sel) + em0.davg = davg + em0.dstd = dstd + em1 = DescrptYourName.deserialize(em0.serialize()) + mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist) + mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) + for ii in [0, 4]: # descriptor and sw + np.testing.assert_allclose(mm0[ii], mm1[ii]) +``` + +Reference: `source/tests/common/dpmodel/test_descriptor_se_t.py` + +## 7b. pt_expt unit tests + +**Create** `source/tests/pt_expt/descriptor/test_.py` + +Three test types: consistency, exportable, make_fx. Use `itertools.product` loops inside methods (not `pytest.mark.parametrize`) when the class inherits `unittest.TestCase`. + +```python +import itertools +import unittest +import numpy as np +import torch +from torch.fx.experimental.proxy_tensor import make_fx +from deepmd.dpmodel.descriptor import DescrptYourName as DPDescrptYourName +from deepmd.pt_expt.descriptor.your_name import DescrptYourName +from deepmd.pt_expt.utils import env +from deepmd.pt_expt.utils.env import PRECISION_DICT +from ...pt.model.test_env_mat import TestCaseSingleFrameWithNlist +from ...pt.model.test_mlp import get_tols +from ...seed import GLOBAL_SEED + + +class TestDescrptYourName(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + for idt, prec in itertools.product([False, True], ["float64", "float32"]): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptYourName( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + # Forward + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + # Serialize/deserialize round-trip + dd1 = DescrptYourName.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # Permutation equivariance + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + ) + # Compare with dpmodel + dd2 = DPDescrptYourName.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call(self.coord_ext, self.atype_ext, self.nlist) + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), rd2, rtol=rtol, atol=atol + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), sw2, rtol=rtol, atol=atol + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + for idt, prec in itertools.product([False, True], ["float64", "float32"]): + dtype = PRECISION_DICT[prec] + dd0 = DescrptYourName( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + def test_make_fx(self) -> None: + """Verify make_fx traces forward + autograd (for forward_lower).""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + for idt, prec in itertools.product([False, True], ["float64", "float32"]): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptYourName( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) +``` + +Reference: `source/tests/pt_expt/descriptor/test_se_t.py` + +## 7e. array_api_strict wrapper + +**Create** `source/tests/array_api_strict/descriptor/.py` + +```python +from typing import Any +from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP +from ..common import to_array_api_strict_array +from ..utils.exclude_mask import PairExcludeMask +from ..utils.network import NetworkCollection +from .base_descriptor import BaseDescriptor + + +@BaseDescriptor.register("your_name") +class DescrptYourName(DescrptYourNameDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + return super().__setattr__(name, value) +``` + +**Edit** `source/tests/array_api_strict/descriptor/__init__.py` — add import and `__all__` entry. + +Reference: `source/tests/array_api_strict/descriptor/se_e2_r.py` + +## 7h. Cross-backend consistency test + +**Create** `source/tests/consistent/descriptor/test_.py` + +```python +import unittest +from typing import Any +import numpy as np +from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP +from deepmd.env import GLOBAL_NP_FLOAT_PRECISION +from deepmd.utils.argcheck import descrpt_your_name_args +from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_PT_EXPT, + INSTALLED_TF, + CommonTest, + parameterized, +) +from .common import DescriptorTest + +# Conditional imports for each backend +if INSTALLED_PT: + from deepmd.pt.model.descriptor.your_name import DescrptYourName as YourNamePT +else: + YourNamePT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.your_name import DescrptYourName as YourNamePTExpt +else: + YourNamePTExpt = None +if INSTALLED_TF: + from deepmd.tf.descriptor.your_name import DescrptYourName as YourNameTF +else: + YourNameTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.your_name import DescrptYourName as YourNameJAX +else: + YourNameJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.your_name import ( + DescrptYourName as YourNameStrict, + ) +else: + YourNameStrict = None + + +@parameterized( + (True, False), # resnet_dt + ("float32", "float64"), # precision +) +class TestYourName(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + resnet_dt, precision = self.param + return { + "sel": [9, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "resnet_dt": resnet_dt, + "precision": precision, + "seed": 1145141919810, + "activation_function": "relu", + } + + # Set skip_* properties based on which backends are available + @property + def skip_pt(self): + return not INSTALLED_PT + + @property + def skip_pt_expt(self): + return CommonTest.skip_pt_expt + + @property + def skip_dp(self): + return CommonTest.skip_dp + + @property + def skip_jax(self): + return not INSTALLED_JAX + + @property + def skip_array_api_strict(self): + return not INSTALLED_ARRAY_API_STRICT + + tf_class = YourNameTF + dp_class = DescrptYourNameDP + pt_class = YourNamePT + pt_expt_class = YourNamePTExpt + jax_class = YourNameJAX + array_api_strict_class = YourNameStrict + args = descrpt_your_name_args() + + def setUp(self) -> None: + CommonTest.setUp(self) + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # Implement eval_* methods using self.eval_*_descriptor() helpers + def eval_dp(self, dp_obj): + return self.eval_dp_descriptor( + dp_obj, self.natoms, self.coords, self.atype, self.box + ) + + def eval_pt(self, pt_obj): + return self.eval_pt_descriptor( + pt_obj, self.natoms, self.coords, self.atype, self.box + ) + + def eval_pt_expt(self, pt_expt_obj): + return self.eval_pt_expt_descriptor( + pt_expt_obj, self.natoms, self.coords, self.atype, self.box + ) + + def eval_jax(self, jax_obj): + return self.eval_jax_descriptor( + jax_obj, self.natoms, self.coords, self.atype, self.box + ) + + def eval_array_api_strict(self, obj): + return self.eval_array_api_strict_descriptor( + obj, self.natoms, self.coords, self.atype, self.box + ) + + def extract_ret(self, ret, backend): + return (ret[0],) + + # For mixed_types descriptors (dpa1, dpa2, dpa3, se_atten_v2), + # pass mixed_types=True to eval_*_descriptor calls. + + @property + def rtol(self) -> float: + _, precision = self.param + return 1e-10 if precision == "float64" else 1e-4 + + @property + def atol(self) -> float: + _, precision = self.param + return 1e-10 if precision == "float64" else 1e-4 +``` + +Reference: `source/tests/consistent/descriptor/test_se_r.py` From 9d892ebb11e967e4b2a1e257bfcb6a9f85361058 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 23:34:05 +0800 Subject: [PATCH 2/8] update due to updated torch_module --- skills/add-descriptor/SKILL.md | 36 ++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/skills/add-descriptor/SKILL.md b/skills/add-descriptor/SKILL.md index 2d06c817b0..fe4dd11caf 100644 --- a/skills/add-descriptor/SKILL.md +++ b/skills/add-descriptor/SKILL.md @@ -103,7 +103,13 @@ For nested sub-components, define wrapper classes bottom-up. See `deepmd/jax/des **Create** `deepmd/pt_expt/descriptor/.py` -Pattern: `@torch_module` decorator + `forward()` method delegating to `call()`. +The `@torch_module` decorator handles everything automatically: + +- Auto-generates `forward()` delegating to `call()` (and `forward_lower()` from `call_lower()`) +- Auto-generates `__setattr__` that converts numpy arrays to torch buffers and dpmodel objects to pt_expt modules via a converter registry +- Any unregistered `NativeOP` assigned as an attribute will raise `TypeError` — register it first + +Simple descriptors (no custom sub-components) need only an empty body: ```python from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP @@ -114,11 +120,33 @@ from deepmd.pt_expt.descriptor.base_descriptor import BaseDescriptor @BaseDescriptor.register("your_name") @torch_module class DescrptYourName(DescrptYourNameDP): - def forward(self, *args, **kwargs): - return self.call(*args, **kwargs) + pass ``` -For nested sub-components, wrap and register bottom-up with `register_dpmodel_mapping`. See `deepmd/pt_expt/descriptor/se_t_tebd.py` + `se_t_tebd_block.py`. +Standard dpmodel sub-components (`NetworkCollection`, `EmbeddingNet`, `PairExcludeMask`, `EnvMat`, `TypeEmbedNet`) are pre-registered in `deepmd/pt_expt/utils/` and converted automatically. No `__setattr__` override needed. + +For **custom sub-components** (e.g., a new block class inheriting `NativeOP`), create a separate wrapper file and register bottom-up with `register_dpmodel_mapping`: + +```python +# deepmd/pt_expt/descriptor/your_block.py +from deepmd.dpmodel.descriptor.your_block import YourBlock as YourBlockDP +from deepmd.pt_expt.common import register_dpmodel_mapping, torch_module + + +@torch_module +class YourBlock(YourBlockDP): + pass + + +register_dpmodel_mapping( + YourBlockDP, + lambda v: YourBlock.deserialize(v.serialize()), +) +``` + +Then import this module in `deepmd/pt_expt/descriptor/__init__.py` for its side effect (the registration must happen before the parent descriptor is instantiated). + +Reference: `deepmd/pt_expt/descriptor/se_t_tebd.py` + `se_t_tebd_block.py` **Edit** `deepmd/pt_expt/descriptor/__init__.py` — add import and `__all__` entry. From 779c37d95adf2eb1d0f8f178028dc8af2c74c93b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 23:48:04 +0800 Subject: [PATCH 3/8] update based on latest feat-descriptor-1 branch --- skills/add-descriptor/SKILL.md | 24 +- .../references/test-patterns.md | 330 ++++++++++-------- 2 files changed, 204 insertions(+), 150 deletions(-) diff --git a/skills/add-descriptor/SKILL.md b/skills/add-descriptor/SKILL.md index fe4dd11caf..7d50b84a3b 100644 --- a/skills/add-descriptor/SKILL.md +++ b/skills/add-descriptor/SKILL.md @@ -5,7 +5,7 @@ license: LGPL-3.0-or-later compatibility: Requires Python 3.10+, numpy, pytest. Optional backends for full testing (torch, jax, paddle). metadata: author: deepmd-kit - version: "1.0" + version: "2.0" --- # Adding a New Descriptor to deepmd-kit @@ -172,16 +172,18 @@ Reference: `deepmd/pd/model/descriptor/se_a.py` Seven test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates. -| Test | File | Purpose | -| --------------------- | -------------------------------------------------------------- | ----------------------------------- | -| 7a. dpmodel | `source/tests/common/dpmodel/test_descriptor_.py` | Serialize/deserialize round-trip | -| 7b. pt_expt | `source/tests/pt_expt/descriptor/test_.py` | Consistency + exportable + make_fx | -| 7c. PT | `source/tests/pt/model/test_descriptor_.py` | PT hard-coded tests (if applicable) | -| 7d. PD | `source/tests/pd/model/test_descriptor_.py` | PD hard-coded tests (if applicable) | -| 7e. array_api_strict | `source/tests/array_api_strict/descriptor/.py` | Wrapper for consistency tests | -| 7f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parameterized entry | -| 7g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parameterized entry | -| 7h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend comparison | +pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not inherit from `unittest.TestCase`, and use `setup_method` (not `setUp`). + +| Test | File | Purpose | +| --------------------- | -------------------------------------------------------------- | ------------------------------------------------- | +| 7a. dpmodel | `source/tests/common/dpmodel/test_descriptor_.py` | Serialize/deserialize round-trip | +| 7b. pt_expt | `source/tests/pt_expt/descriptor/test_.py` | Consistency + exportable + make_fx (float64 only) | +| 7c. PT | `source/tests/pt/model/test_descriptor_.py` | PT hard-coded tests (if applicable) | +| 7d. PD | `source/tests/pd/model/test_descriptor_.py` | PD hard-coded tests (if applicable) | +| 7e. array_api_strict | `source/tests/array_api_strict/descriptor/.py` | Wrapper for consistency tests | +| 7f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parameterized entry | +| 7g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parameterized entry | +| 7h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend + API consistency | ## Verification diff --git a/skills/add-descriptor/references/test-patterns.md b/skills/add-descriptor/references/test-patterns.md index 5de3543843..0c58883f46 100644 --- a/skills/add-descriptor/references/test-patterns.md +++ b/skills/add-descriptor/references/test-patterns.md @@ -6,8 +6,11 @@ ```python import unittest + import numpy as np + from deepmd.dpmodel.descriptor import DescrptYourName + from ...seed import GLOBAL_SEED from .case_single_frame_with_nlist import TestCaseSingleFrameWithNlist @@ -40,156 +43,167 @@ Reference: `source/tests/common/dpmodel/test_descriptor_se_t.py` **Create** `source/tests/pt_expt/descriptor/test_.py` -Three test types: consistency, exportable, make_fx. Use `itertools.product` loops inside methods (not `pytest.mark.parametrize`) when the class inherits `unittest.TestCase`. +Three test types: consistency, exportable, make_fx. Use `pytest.mark.parametrize` with trailing comments explaining each parameter. Do **not** inherit from `unittest.TestCase`. Use `setup_method` instead of `setUp`. ```python -import itertools -import unittest import numpy as np +import pytest import torch from torch.fx.experimental.proxy_tensor import make_fx + from deepmd.dpmodel.descriptor import DescrptYourName as DPDescrptYourName from deepmd.pt_expt.descriptor.your_name import DescrptYourName from deepmd.pt_expt.utils import env from deepmd.pt_expt.utils.env import PRECISION_DICT + from ...pt.model.test_env_mat import TestCaseSingleFrameWithNlist from ...pt.model.test_mlp import get_tols from ...seed import GLOBAL_SEED -class TestDescrptYourName(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDescrptYourName(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_consistency(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_consistency(self, idt, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) - dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptYourName( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + # Serialize/deserialize round-trip + dd1 = DescrptYourName.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # Permutation equivariance + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # Compare with dpmodel + dd2 = DPDescrptYourName.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), rd2, rtol=rtol, atol=atol, err_msg=err_msg + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), sw2, rtol=rtol, atol=atol, err_msg=err_msg + ) - for idt, prec in itertools.product([False, True], ["float64", "float32"]): - dtype = PRECISION_DICT[prec] - rtol, atol = get_tols(prec) - dd0 = DescrptYourName( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - # Forward - rd0, _, _, _, _ = dd0( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - # Serialize/deserialize round-trip - dd1 = DescrptYourName.deserialize(dd0.serialize()) - rd1, _, _, _, sw1 = dd1( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd1.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - ) - # Permutation equivariance - np.testing.assert_allclose( - rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], - rd0.detach().cpu().numpy()[1], - rtol=rtol, - atol=atol, - ) - # Compare with dpmodel - dd2 = DPDescrptYourName.deserialize(dd0.serialize()) - rd2, _, _, _, sw2 = dd2.call(self.coord_ext, self.atype_ext, self.nlist) - np.testing.assert_allclose( - rd1.detach().cpu().numpy(), rd2, rtol=rtol, atol=atol - ) - np.testing.assert_allclose( - sw1.detach().cpu().numpy(), sw2, rtol=rtol, atol=atol - ) - - def test_exportable(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, idt, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) - dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + dd0 = DescrptYourName( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) - for idt, prec in itertools.product([False, True], ["float64", "float32"]): - dtype = PRECISION_DICT[prec] - dd0 = DescrptYourName( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - dd0 = dd0.eval() - inputs = ( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - torch.export.export(dd0, inputs) - - def test_make_fx(self) -> None: + @pytest.mark.parametrize("prec", ["float64"]) # precision — float64 only + def test_make_fx(self, prec) -> None: """Verify make_fx traces forward + autograd (for forward_lower).""" rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) - dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) - - for idt, prec in itertools.product([False, True], ["float64", "float32"]): - dtype = PRECISION_DICT[prec] - rtol, atol = get_tols(prec) - dd0 = DescrptYourName( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - dd0 = dd0.eval() - - coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) - atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) - nlist = torch.tensor(self.nlist, dtype=int, device=self.device) - - def fn(coord_ext, atype_ext, nlist): - coord_ext = coord_ext.detach().requires_grad_(True) - rd = dd0(coord_ext, atype_ext, nlist)[0] - grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] - return rd, grad - - rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) - traced = make_fx(fn)(coord_ext, atype_ext, nlist) - rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) - np.testing.assert_allclose( - rd_eager.detach().cpu().numpy(), - rd_traced.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - ) - np.testing.assert_allclose( - grad_eager.detach().cpu().numpy(), - grad_traced.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - ) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptYourName( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) ``` Reference: `source/tests/pt_expt/descriptor/test_se_t.py` @@ -200,7 +214,9 @@ Reference: `source/tests/pt_expt/descriptor/test_se_t.py` ```python from typing import Any + from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP + from ..common import to_array_api_strict_array from ..utils.exclude_mask import PairExcludeMask from ..utils.network import NetworkCollection @@ -230,13 +246,18 @@ Reference: `source/tests/array_api_strict/descriptor/se_e2_r.py` **Create** `source/tests/consistent/descriptor/test_.py` +Two test classes: one for numerical consistency (`CommonTest`), one for API consistency (`DescriptorAPITest`). + ```python import unittest from typing import Any + import numpy as np + from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP from deepmd.env import GLOBAL_NP_FLOAT_PRECISION from deepmd.utils.argcheck import descrpt_your_name_args + from ..common import ( INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, @@ -246,7 +267,7 @@ from ..common import ( CommonTest, parameterized, ) -from .common import DescriptorTest +from .common import DescriptorAPITest, DescriptorTest # Conditional imports for each backend if INSTALLED_PT: @@ -292,25 +313,26 @@ class TestYourName(CommonTest, DescriptorTest, unittest.TestCase): "activation_function": "relu", } - # Set skip_* properties based on which backends are available @property - def skip_pt(self): + def skip_pt(self) -> bool: return not INSTALLED_PT @property - def skip_pt_expt(self): + def skip_pt_expt(self) -> bool: + # Add parameter-based skips here if needed, e.g.: + # return (not some_supported_param) or CommonTest.skip_pt_expt return CommonTest.skip_pt_expt @property - def skip_dp(self): + def skip_dp(self) -> bool: return CommonTest.skip_dp @property - def skip_jax(self): + def skip_jax(self) -> bool: return not INSTALLED_JAX @property - def skip_array_api_strict(self): + def skip_array_api_strict(self) -> bool: return not INSTALLED_ARRAY_API_STRICT tf_class = YourNameTF @@ -354,38 +376,37 @@ class TestYourName(CommonTest, DescriptorTest, unittest.TestCase): ) self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) - # Implement eval_* methods using self.eval_*_descriptor() helpers - def eval_dp(self, dp_obj): + # Implement eval_* methods using self.eval_*_descriptor() helpers. + # For mixed_types descriptors (dpa1, dpa2, dpa3, se_atten_v2), + # pass mixed_types=True to each eval call. + def eval_dp(self, dp_obj: Any) -> Any: return self.eval_dp_descriptor( dp_obj, self.natoms, self.coords, self.atype, self.box ) - def eval_pt(self, pt_obj): + def eval_pt(self, pt_obj: Any) -> Any: return self.eval_pt_descriptor( pt_obj, self.natoms, self.coords, self.atype, self.box ) - def eval_pt_expt(self, pt_expt_obj): + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: return self.eval_pt_expt_descriptor( pt_expt_obj, self.natoms, self.coords, self.atype, self.box ) - def eval_jax(self, jax_obj): + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, self.natoms, self.coords, self.atype, self.box ) - def eval_array_api_strict(self, obj): + def eval_array_api_strict(self, obj: Any) -> Any: return self.eval_array_api_strict_descriptor( obj, self.natoms, self.coords, self.atype, self.box ) - def extract_ret(self, ret, backend): + def extract_ret(self, ret: Any, backend: Any) -> tuple: return (ret[0],) - # For mixed_types descriptors (dpa1, dpa2, dpa3, se_atten_v2), - # pass mixed_types=True to eval_*_descriptor calls. - @property def rtol(self) -> float: _, precision = self.param @@ -395,6 +416,37 @@ class TestYourName(CommonTest, DescriptorTest, unittest.TestCase): def atol(self) -> float: _, precision = self.param return 1e-10 if precision == "float64" else 1e-4 + + +@parameterized( + ("float64",), # precision — API test only needs one precision +) +class TestYourNameAPI(DescriptorAPITest, unittest.TestCase): + @property + def data(self) -> dict: + (precision,) = self.param + return { + "sel": [9, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "precision": precision, + "seed": 1145141919810, + } + + dp_class = DescrptYourNameDP + pt_class = YourNamePT + pt_expt_class = YourNamePTExpt + args = descrpt_your_name_args() + ntypes = 2 + + @property + def skip_pt(self) -> bool: + return not INSTALLED_PT + + @property + def skip_pt_expt(self) -> bool: + return not INSTALLED_PT_EXPT ``` -Reference: `source/tests/consistent/descriptor/test_se_r.py` +Reference: `source/tests/consistent/descriptor/test_se_t.py` From 0d29f70ad1c577ca134b630863053cca6b096533 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 14:31:42 +0800 Subject: [PATCH 4/8] update --- .../references/dpmodel-implementation.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/skills/add-descriptor/references/dpmodel-implementation.md b/skills/add-descriptor/references/dpmodel-implementation.md index f4717b41a2..21c26b96ad 100644 --- a/skills/add-descriptor/references/dpmodel-implementation.md +++ b/skills/add-descriptor/references/dpmodel-implementation.md @@ -70,6 +70,20 @@ xp = array_api_compat.array_namespace(coord_ext) device = array_api_compat.device(coord_ext) ``` +To check whether a method is within the [array API standard](https://data-apis.org/array-api/), use the following command (query `zeros_like` for example): + +```sh +uvx --from array-api-strict python -c "import array_api_strict,pydoc;print(pydoc.render_doc(array_api_strict.zeros_like))" +``` + +If the method exists, its doc will be printed; otherwise, `AttributeError` is thrown. + +For methods of an `Array` class, call (query `Array.shape` for example): + +```sh +uvx --from array-api-strict python -c "import array_api_strict,pydoc;print(pydoc.render_doc(array_api_strict._array_object.Array.shape))" +``` + Rules: 1. **Never use `np.einsum` on arrays that might be torch tensors** — torch disables `__array_function__` so `np.einsum` fails on tensors with `requires_grad=True`. Use `xp.sum` with broadcasting: From e937a75311999f46a99264a3a3411d61b0291f19 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 14:44:08 +0800 Subject: [PATCH 5/8] fixes --- skills/add-descriptor/SKILL.md | 8 +++++--- skills/add-descriptor/references/test-patterns.md | 9 ++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/skills/add-descriptor/SKILL.md b/skills/add-descriptor/SKILL.md index 7d50b84a3b..7709687947 100644 --- a/skills/add-descriptor/SKILL.md +++ b/skills/add-descriptor/SKILL.md @@ -75,6 +75,8 @@ Pattern: `@flax_module` decorator + custom `__setattr__` for attribute conversio from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP from deepmd.jax.common import ArrayAPIVariable, flax_module, to_jax_array from deepmd.jax.descriptor.base_descriptor import BaseDescriptor +from deepmd.jax.utils.exclude_mask import PairExcludeMask +from deepmd.jax.utils.network import NetworkCollection @BaseDescriptor.register("your_name") @@ -170,7 +172,7 @@ Reference: `deepmd/pd/model/descriptor/se_a.py` ## Step 7: Write tests -Seven test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates. +Eight test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates. pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not inherit from `unittest.TestCase`, and use `setup_method` (not `setUp`). @@ -181,8 +183,8 @@ pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not in | 7c. PT | `source/tests/pt/model/test_descriptor_.py` | PT hard-coded tests (if applicable) | | 7d. PD | `source/tests/pd/model/test_descriptor_.py` | PD hard-coded tests (if applicable) | | 7e. array_api_strict | `source/tests/array_api_strict/descriptor/.py` | Wrapper for consistency tests | -| 7f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parameterized entry | -| 7g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parameterized entry | +| 7f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parametrized entry | +| 7g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parametrized entry | | 7h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend + API consistency | ## Verification diff --git a/skills/add-descriptor/references/test-patterns.md b/skills/add-descriptor/references/test-patterns.md index 0c58883f46..1b24e7d241 100644 --- a/skills/add-descriptor/references/test-patterns.md +++ b/skills/add-descriptor/references/test-patterns.md @@ -88,7 +88,7 @@ class TestDescrptYourName(TestCaseSingleFrameWithNlist): ).to(self.device) dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - rd0, _, _, _, _ = dd0( + rd0, _, _, _, sw0 = dd0( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), torch.tensor(self.nlist, dtype=int, device=self.device), @@ -107,6 +107,13 @@ class TestDescrptYourName(TestCaseSingleFrameWithNlist): atol=atol, err_msg=err_msg, ) + np.testing.assert_allclose( + sw0.detach().cpu().numpy(), + sw1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) # Permutation equivariance np.testing.assert_allclose( rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], From 4cf4455b52f4949fcf7dfea3f1244cfd64d5c6f9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 20:22:00 +0800 Subject: [PATCH 6/8] fix --- skills/add-descriptor/SKILL.md | 54 +++++++++++-------- .../references/test-patterns.md | 21 ++++---- 2 files changed, 45 insertions(+), 30 deletions(-) diff --git a/skills/add-descriptor/SKILL.md b/skills/add-descriptor/SKILL.md index 7709687947..98163d745d 100644 --- a/skills/add-descriptor/SKILL.md +++ b/skills/add-descriptor/SKILL.md @@ -162,7 +162,17 @@ PT descriptors are fully reimplemented in PyTorch (not wrapping dpmodel). They i Reference: `deepmd/pt/model/descriptor/se_a.py` -## Step 6: Hard-code for PD backend (if needed) +## Step 6: Hard-code for TF backend (if needed) + +**Create** `deepmd/tf/descriptor/.py` + +TF descriptors are fully reimplemented in TensorFlow. They inherit from `BaseDescriptor` and implement the TF computation graph. + +**Edit** `deepmd/tf/descriptor/__init__.py` — add import. + +Reference: `deepmd/tf/descriptor/se_a.py` + +## Step 7: Hard-code for PD backend (if needed) Same as PT but using Paddle. Inherit from `BaseDescriptor` and `paddle.nn.Layer`. @@ -170,7 +180,7 @@ Same as PT but using Paddle. Inherit from `BaseDescriptor` and `paddle.nn.Layer` Reference: `deepmd/pd/model/descriptor/se_a.py` -## Step 7: Write tests +## Step 8: Write tests Eight test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates. @@ -178,14 +188,14 @@ pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not in | Test | File | Purpose | | --------------------- | -------------------------------------------------------------- | ------------------------------------------------- | -| 7a. dpmodel | `source/tests/common/dpmodel/test_descriptor_.py` | Serialize/deserialize round-trip | -| 7b. pt_expt | `source/tests/pt_expt/descriptor/test_.py` | Consistency + exportable + make_fx (float64 only) | -| 7c. PT | `source/tests/pt/model/test_descriptor_.py` | PT hard-coded tests (if applicable) | -| 7d. PD | `source/tests/pd/model/test_descriptor_.py` | PD hard-coded tests (if applicable) | -| 7e. array_api_strict | `source/tests/array_api_strict/descriptor/.py` | Wrapper for consistency tests | -| 7f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parametrized entry | -| 7g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parametrized entry | -| 7h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend + API consistency | +| 8a. dpmodel | `source/tests/common/dpmodel/test_descriptor_.py` | Serialize/deserialize round-trip | +| 8b. pt_expt | `source/tests/pt_expt/descriptor/test_.py` | Consistency + exportable + make_fx (float64 only) | +| 8c. PT | `source/tests/pt/model/test_descriptor_.py` | PT hard-coded tests (if applicable) | +| 8d. PD | `source/tests/pd/model/test_descriptor_.py` | PD hard-coded tests (if applicable) | +| 8e. array_api_strict | `source/tests/array_api_strict/descriptor/.py` | Wrapper for consistency tests | +| 8f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parametrized entry | +| 8g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parametrized entry | +| 8h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend + API consistency | ## Verification @@ -225,14 +235,16 @@ print('Round-trip OK:', d.get_dim_out() == d2.get_dim_out()) | 4 | Edit | `deepmd/pt_expt/descriptor/__init__.py` | | 5 | Create | `deepmd/pt/model/descriptor/.py` (if needed) | | 5 | Edit | `deepmd/pt/model/descriptor/__init__.py` (if needed) | -| 6 | Create | `deepmd/pd/model/descriptor/.py` (if needed) | -| 6 | Edit | `deepmd/pd/model/descriptor/__init__.py` (if needed) | -| 7a | Create | `source/tests/common/dpmodel/test_descriptor_.py` | -| 7b | Create | `source/tests/pt_expt/descriptor/test_.py` | -| 7c | Create | `source/tests/pt/model/test_descriptor_.py` (if PT) | -| 7d | Create | `source/tests/pd/model/test_descriptor_.py` (if PD) | -| 7e | Create | `source/tests/array_api_strict/descriptor/.py` | -| 7e | Edit | `source/tests/array_api_strict/descriptor/__init__.py` | -| 7f | Edit | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | -| 7g | Edit | `source/tests/universal/pt/descriptor/test_descriptor.py` | -| 7h | Create | `source/tests/consistent/descriptor/test_.py` | +| 6 | Create | `deepmd/tf/descriptor/.py` (if needed) | +| 6 | Edit | `deepmd/tf/descriptor/__init__.py` (if needed) | +| 7 | Create | `deepmd/pd/model/descriptor/.py` (if needed) | +| 7 | Edit | `deepmd/pd/model/descriptor/__init__.py` (if needed) | +| 8a | Create | `source/tests/common/dpmodel/test_descriptor_.py` | +| 8b | Create | `source/tests/pt_expt/descriptor/test_.py` | +| 8c | Create | `source/tests/pt/model/test_descriptor_.py` (if PT) | +| 8d | Create | `source/tests/pd/model/test_descriptor_.py` (if PD) | +| 8e | Create | `source/tests/array_api_strict/descriptor/.py` | +| 8e | Edit | `source/tests/array_api_strict/descriptor/__init__.py` | +| 8f | Edit | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | +| 8g | Edit | `source/tests/universal/pt/descriptor/test_descriptor.py` | +| 8h | Create | `source/tests/consistent/descriptor/test_.py` | diff --git a/skills/add-descriptor/references/test-patterns.md b/skills/add-descriptor/references/test-patterns.md index 1b24e7d241..cc1145b820 100644 --- a/skills/add-descriptor/references/test-patterns.md +++ b/skills/add-descriptor/references/test-patterns.md @@ -1,6 +1,6 @@ # Test Patterns for Descriptors -## 7a. dpmodel self-consistency test +## 8a. dpmodel self-consistency test **Create** `source/tests/common/dpmodel/test_descriptor_.py` @@ -25,7 +25,9 @@ class TestDescrptYourName(unittest.TestCase, TestCaseSingleFrameWithNlist): davg = rng.normal( size=(self.nt, nnei, 4) ) # 4 for full env mat, 1 for radial-only - dstd = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + dstd = 0.1 + np.abs( + rng.normal(size=(self.nt, nnei, 4)) + ) # 4 for full env mat, 1 for radial-only em0 = DescrptYourName(self.rcut, self.rcut_smth, self.sel) em0.davg = davg @@ -39,7 +41,7 @@ class TestDescrptYourName(unittest.TestCase, TestCaseSingleFrameWithNlist): Reference: `source/tests/common/dpmodel/test_descriptor_se_t.py` -## 7b. pt_expt unit tests +## 8b. pt_expt unit tests **Create** `source/tests/pt_expt/descriptor/test_.py` @@ -215,7 +217,7 @@ class TestDescrptYourName(TestCaseSingleFrameWithNlist): Reference: `source/tests/pt_expt/descriptor/test_se_t.py` -## 7e. array_api_strict wrapper +## 8e. array_api_strict wrapper **Create** `source/tests/array_api_strict/descriptor/.py` @@ -249,7 +251,7 @@ class DescrptYourName(DescrptYourNameDP): Reference: `source/tests/array_api_strict/descriptor/se_e2_r.py` -## 7h. Cross-backend consistency test +## 8h. Cross-backend consistency test **Create** `source/tests/consistent/descriptor/test_.py` @@ -276,7 +278,8 @@ from ..common import ( ) from .common import DescriptorAPITest, DescriptorTest -# Conditional imports for each backend +# Conditional imports for each backend. +# Omit any backend that has no implementation for this descriptor. if INSTALLED_PT: from deepmd.pt.model.descriptor.your_name import DescrptYourName as YourNamePT else: @@ -322,7 +325,7 @@ class TestYourName(CommonTest, DescriptorTest, unittest.TestCase): @property def skip_pt(self) -> bool: - return not INSTALLED_PT + return CommonTest.skip_pt @property def skip_pt_expt(self) -> bool: @@ -336,11 +339,11 @@ class TestYourName(CommonTest, DescriptorTest, unittest.TestCase): @property def skip_jax(self) -> bool: - return not INSTALLED_JAX + return CommonTest.skip_jax @property def skip_array_api_strict(self) -> bool: - return not INSTALLED_ARRAY_API_STRICT + return CommonTest.skip_array_api_strict tf_class = YourNameTF dp_class = DescrptYourNameDP From 5304608d3995710b67e73f1a5f58670f80b6e843 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 21 Feb 2026 22:50:39 +0800 Subject: [PATCH 7/8] add doc for new descriptors --- skills/add-descriptor/SKILL.md | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/skills/add-descriptor/SKILL.md b/skills/add-descriptor/SKILL.md index 98163d745d..a7529dcf2b 100644 --- a/skills/add-descriptor/SKILL.md +++ b/skills/add-descriptor/SKILL.md @@ -197,6 +197,65 @@ pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not in | 8g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parametrized entry | | 8h. Consistency | `source/tests/consistent/descriptor/test_.py` | Cross-backend + API consistency | +## Step 9: Write documentation + +**Create** `doc/model/.md` + +Each descriptor needs a documentation page in `doc/model/`. Use MyST Markdown format with Sphinx extensions. List supported backends using icon substitutions. + +Template: + +````markdown +# Descriptor `"your_name"` {{ pytorch_icon }} {{ dpmodel_icon }} + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +::: + +Brief description of what the descriptor is and its theoretical motivation. + +## Theory + +Mathematical formulation using LaTeX: + +```math + \mathcal{D}^i = ... +``` + +## Instructions + +Example JSON configuration: + +```json +"descriptor": { + "type": "your_name", + "sel": [46, 92], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [10, 20, 40], + "resnet_dt": false, + "seed": 1 +} +``` + +Explain key parameters and link to the argument schema using `{ref}` directives, +e.g. `{ref}rcut `. +```` + +Available backend icons: `{{ tensorflow_icon }}`, `{{ pytorch_icon }}`, `{{ jax_icon }}`, `{{ paddle_icon }}`, `{{ dpmodel_icon }}`. Only list backends that actually support this descriptor. + +**Edit** `doc/model/index.rst` — add the new page to the `toctree`: + +```rst +.. toctree:: + :maxdepth: 1 + + ... + +``` + +**Reference docs**: `doc/model/train-se-e2-r.md` (simple), `doc/model/dpa2.md` (modern) + ## Verification ```bash @@ -248,3 +307,5 @@ print('Round-trip OK:', d.get_dim_out() == d2.get_dim_out()) | 8f | Edit | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | | 8g | Edit | `source/tests/universal/pt/descriptor/test_descriptor.py` | | 8h | Create | `source/tests/consistent/descriptor/test_.py` | +| 9 | Create | `doc/model/.md` | +| 9 | Edit | `doc/model/index.rst` | From c5cc5c599ff293c12ab6e5cf23d618816286ce12 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 21 Feb 2026 23:01:42 +0800 Subject: [PATCH 8/8] mv skills to .github --- {skills => .github/skills}/add-descriptor/SKILL.md | 0 .../skills}/add-descriptor/references/dpmodel-implementation.md | 0 .../skills}/add-descriptor/references/test-patterns.md | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {skills => .github/skills}/add-descriptor/SKILL.md (100%) rename {skills => .github/skills}/add-descriptor/references/dpmodel-implementation.md (100%) rename {skills => .github/skills}/add-descriptor/references/test-patterns.md (100%) diff --git a/skills/add-descriptor/SKILL.md b/.github/skills/add-descriptor/SKILL.md similarity index 100% rename from skills/add-descriptor/SKILL.md rename to .github/skills/add-descriptor/SKILL.md diff --git a/skills/add-descriptor/references/dpmodel-implementation.md b/.github/skills/add-descriptor/references/dpmodel-implementation.md similarity index 100% rename from skills/add-descriptor/references/dpmodel-implementation.md rename to .github/skills/add-descriptor/references/dpmodel-implementation.md diff --git a/skills/add-descriptor/references/test-patterns.md b/.github/skills/add-descriptor/references/test-patterns.md similarity index 100% rename from skills/add-descriptor/references/test-patterns.md rename to .github/skills/add-descriptor/references/test-patterns.md