Commit ccad0fb

wanghan-iapcm and Han Wang authored
feat: add skills for adding new descriptors (#5249)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Documentation**
  * Added an end-to-end guide for adding new descriptors, including dpmodel implementation requirements, reference patterns, registration/serialization guidance, backend wrapper patterns, and array‑API compatibility rules.
* **Tests**
  * Added comprehensive test scaffolds: serialization round‑trip, self‑consistency, export/tracing scenarios, array‑API‑strict wrappers, and multi‑backend consistency checks across backends.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 8ed49be commit ccad0fb


3 files changed

+891
-0
lines changed

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
1+
---
2+
name: add-descriptor
3+
description: Guides through adding a new descriptor type to deepmd-kit. Covers implementing in dpmodel (array-API-compatible), wrapping for JAX/pt_expt backends, hard-coding for PT/PD, registering arguments, and writing all required tests.
4+
license: LGPL-3.0-or-later
5+
compatibility: Requires Python 3.10+, numpy, pytest. Optional backends for full testing (torch, jax, paddle).
6+
metadata:
7+
author: deepmd-kit
8+
version: "2.0"
9+
---
10+
11+
# Adding a New Descriptor to deepmd-kit
12+
13+
Follow these steps in order. Each step lists files to create/modify and patterns to follow.
14+
15+
## Step 1: Implement in dpmodel
16+
17+
**Create** `deepmd/dpmodel/descriptor/<name>.py`
18+
19+
Inherit from `NativeOP` and `BaseDescriptor`. Register with decorators:
20+
21+
```python
22+
from deepmd.dpmodel import NativeOP
23+
from .base_descriptor import BaseDescriptor
24+
25+
26+
@BaseDescriptor.register("your_name")
27+
@BaseDescriptor.register("alias_name") # optional aliases
28+
class DescrptYourName(NativeOP, BaseDescriptor): ...
29+
```
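The stacked `register` decorators can be read as entries in a name-to-class lookup table. The following is a minimal sketch of that idea only; the `Registry` class and its `get` method are hypothetical names used for illustration, and deepmd-kit's actual plugin mechanism may differ in detail.

```python
class Registry:
    """Hypothetical name-to-class registry (not deepmd-kit's implementation)."""

    def __init__(self) -> None:
        self._classes: dict[str, type] = {}

    def register(self, key: str):
        """Return a decorator that stores the class under `key`, unchanged."""

        def decorator(cls: type) -> type:
            self._classes[key] = cls
            return cls

        return decorator

    def get(self, key: str) -> type:
        """Look up a registered class, failing loudly on unknown names."""
        try:
            return self._classes[key]
        except KeyError:
            raise KeyError(f"unknown descriptor type: {key!r}") from None


descriptors = Registry()


# Decorators apply bottom-up, so the alias is stored first; both keys
# end up pointing at the same class object.
@descriptors.register("your_name")
@descriptors.register("alias_name")
class DescrptYourName:
    pass
```

Because each decorator returns the class unchanged, any number of aliases can be stacked without affecting the class itself.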
30+
31+
Key requirements:
32+
33+
- `__init__`: initialize cutoff, sel, networks, davg/dstd statistics
34+
- `call(coord_ext, atype_ext, nlist, mapping=None)`: forward pass returning `(descriptor, rot_mat, g2, h2, sw)`
35+
- `serialize() -> dict`: save with `@class`, `type`, `@version`, `@variables` keys
36+
- `deserialize(cls, data)`: classmethod that reconstructs the descriptor from the dict produced by `serialize()`
37+
- Property/getter methods: `get_rcut`, `get_sel`, `get_dim_out`, `mixed_types`, etc.
38+
- `__getitem__`/`__setitem__` for `davg`/`dstd` access via multiple key aliases
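The serialization and key-alias requirements listed above can be illustrated with a toy class. This is a sketch under stated assumptions, not a real descriptor: `ToyDescriptor` is hypothetical, plain Python lists stand in for arrays, and the alias spellings (`avg`, `data_avg`, `std`, `data_std`) are assumed for illustration. Check the reference implementations for the keys actually expected.

```python
class ToyDescriptor:
    """Hypothetical stand-in for a dpmodel descriptor; lists replace arrays."""

    # Assumed aliases: several spellings resolve to the same statistic.
    _KEY_ALIASES = {
        "avg": "davg", "data_avg": "davg", "davg": "davg",
        "std": "dstd", "data_std": "dstd", "dstd": "dstd",
    }

    def __init__(self, rcut, rcut_smth, sel):
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.sel = list(sel)
        self.davg = [0.0] * sum(self.sel)
        self.dstd = [1.0] * sum(self.sel)

    def _resolve(self, key):
        try:
            return self._KEY_ALIASES[key]
        except KeyError:
            raise KeyError(key) from None

    def __getitem__(self, key):
        return getattr(self, self._resolve(key))

    def __setitem__(self, key, value):
        setattr(self, self._resolve(key), value)

    def serialize(self) -> dict:
        # Constructor arguments sit at the top level; learned or
        # statistical state goes under "@variables".
        return {
            "@class": "Descriptor",
            "type": "toy",
            "@version": 1,
            "rcut": self.rcut,
            "rcut_smth": self.rcut_smth,
            "sel": list(self.sel),
            "@variables": {"davg": list(self.davg), "dstd": list(self.dstd)},
        }

    @classmethod
    def deserialize(cls, data: dict) -> "ToyDescriptor":
        data = dict(data)  # shallow copy so the caller's dict is untouched
        for meta_key in ("@class", "type", "@version"):
            data.pop(meta_key)
        variables = data.pop("@variables")
        obj = cls(**data)
        obj["davg"] = variables["davg"]
        obj["dstd"] = variables["dstd"]
        return obj
```

A round trip such as `ToyDescriptor.deserialize(d.serialize())` should yield an object whose `serialize()` output equals the original's, which is the property the dpmodel test in Step 8 checks.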
39+
40+
All dpmodel code **must** use `array_api_compat` for cross-backend compatibility (numpy/torch/jax/paddle). See [references/dpmodel-implementation.md](references/dpmodel-implementation.md) for full method table, array API pitfalls, and utilities.
41+
42+
**Reference implementations**:
43+
44+
- Simple: `deepmd/dpmodel/descriptor/se_e2_a.py`
45+
- Three-body: `deepmd/dpmodel/descriptor/se_t.py`
46+
- Attention-based: `deepmd/dpmodel/descriptor/dpa1.py`
47+
48+
## Step 2: Register
49+
50+
**Edit** `deepmd/dpmodel/descriptor/__init__.py` — add import and `__all__` entry.
51+
52+
**Edit** `deepmd/utils/argcheck.py` — register descriptor arguments:
53+
54+
```python
55+
@descrpt_args_plugin.register("your_name", alias=["alias"], doc="Description")
56+
def descrpt_your_name_args() -> list[Argument]:
57+
return [
58+
Argument("sel", [list[int], str], optional=True, default="auto", doc=doc_sel),
59+
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
60+
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
61+
Argument(
62+
"neuron", list[int], optional=True, default=[10, 20, 40], doc=doc_neuron
63+
),
64+
# ... add all constructor parameters
65+
]
66+
```
67+
68+
## Step 3: Wrap for JAX backend
69+
70+
**Create** `deepmd/jax/descriptor/<name>.py`
71+
72+
Pattern: `@flax_module` decorator + custom `__setattr__` for attribute conversion.
73+
74+
```python
75+
from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP
76+
from deepmd.jax.common import ArrayAPIVariable, flax_module, to_jax_array
77+
from deepmd.jax.descriptor.base_descriptor import BaseDescriptor
78+
from deepmd.jax.utils.exclude_mask import PairExcludeMask
79+
from deepmd.jax.utils.network import NetworkCollection
80+
81+
82+
@BaseDescriptor.register("your_name")
83+
@flax_module
84+
class DescrptYourName(DescrptYourNameDP):
85+
def __setattr__(self, name, value):
86+
if name in {"davg", "dstd"}:
87+
value = to_jax_array(value)
88+
if value is not None:
89+
value = ArrayAPIVariable(value)
90+
elif name in {"embeddings"}:
91+
if value is not None:
92+
value = NetworkCollection.deserialize(value.serialize())
93+
elif name == "env_mat":
94+
pass # stateless
95+
elif name == "emask":
96+
value = PairExcludeMask(value.ntypes, value.exclude_types)
97+
return super().__setattr__(name, value)
98+
```
99+
100+
For nested sub-components, define wrapper classes bottom-up. See `deepmd/jax/descriptor/dpa1.py` for example.
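The `__setattr__` pattern relies on a general Python property: assignments made inside the parent's `__init__` are routed through the subclass's `__setattr__`, so the wrapper can convert attributes without reimplementing the constructor. A minimal sketch in plain Python follows; `Wrapped`, `BaseToy`, and `BackendToy` are hypothetical names, and `Wrapped` merely stands in for a backend-specific variable type.

```python
class Wrapped:
    """Hypothetical marker standing in for a backend-specific variable."""

    def __init__(self, value):
        self.value = value


class BaseToy:
    """Plays the role of the dpmodel class: it only assigns attributes."""

    def __init__(self, davg, label):
        self.davg = davg
        self.label = label


class BackendToy(BaseToy):
    """Plays the role of the backend wrapper: converts on assignment."""

    def __setattr__(self, name, value):
        if name in {"davg", "dstd"} and value is not None:
            # Convert statistics to the backend representation.
            value = Wrapped(tuple(value))
        # All other attributes pass through untouched.
        super().__setattr__(name, value)


toy = BackendToy([0.0, 1.0], "example")
```

Here `toy.davg` is a `Wrapped` instance even though `BackendToy` defines no `__init__` of its own, while `toy.label` is stored as given.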
101+
102+
**Edit** `deepmd/jax/descriptor/__init__.py` — add import and `__all__` entry.
103+
104+
## Step 4: Wrap for pt_expt backend
105+
106+
**Create** `deepmd/pt_expt/descriptor/<name>.py`
107+
108+
The `@torch_module` decorator handles the wrapping automatically:
109+
110+
- Auto-generates `forward()` delegating to `call()` (and `forward_lower()` from `call_lower()`)
111+
- Auto-generates `__setattr__` that converts numpy arrays to torch buffers and dpmodel objects to pt_expt modules via a converter registry
112+
- Any unregistered `NativeOP` assigned as an attribute will raise `TypeError` — register it first
113+
114+
Simple descriptors (no custom sub-components) need only an empty body:
115+
116+
```python
117+
from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP
118+
from deepmd.pt_expt.common import torch_module
119+
from deepmd.pt_expt.descriptor.base_descriptor import BaseDescriptor
120+
121+
122+
@BaseDescriptor.register("your_name")
123+
@torch_module
124+
class DescrptYourName(DescrptYourNameDP):
125+
pass
126+
```
127+
128+
Standard dpmodel sub-components (`NetworkCollection`, `EmbeddingNet`, `PairExcludeMask`, `EnvMat`, `TypeEmbedNet`) are pre-registered in `deepmd/pt_expt/utils/` and converted automatically. No `__setattr__` override needed.
129+
130+
For **custom sub-components** (e.g., a new block class inheriting `NativeOP`), create a separate wrapper file and register bottom-up with `register_dpmodel_mapping`:
131+
132+
```python
133+
# deepmd/pt_expt/descriptor/your_block.py
134+
from deepmd.dpmodel.descriptor.your_block import YourBlock as YourBlockDP
135+
from deepmd.pt_expt.common import register_dpmodel_mapping, torch_module
136+
137+
138+
@torch_module
139+
class YourBlock(YourBlockDP):
140+
pass
141+
142+
143+
register_dpmodel_mapping(
144+
YourBlockDP,
145+
lambda v: YourBlock.deserialize(v.serialize()),
146+
)
147+
```
148+
149+
Then import this module in `deepmd/pt_expt/descriptor/__init__.py` for its side effect (the registration must happen before the parent descriptor is instantiated).
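The reason registration order matters can be seen in a toy converter registry. The sketch below is an assumption-laden illustration, not the pt_expt implementation: `register_mapping`, `convert`, and the `Block*` classes are hypothetical, and the exact-type lookup is a simplification chosen for brevity.

```python
from typing import Any, Callable

# Hypothetical registry mapping a source type to its converter.
_CONVERTERS: dict[type, Callable[[Any], Any]] = {}


def register_mapping(source_type: type, converter: Callable[[Any], Any]) -> None:
    """Record how objects of `source_type` are converted for the backend."""
    _CONVERTERS[source_type] = converter


def convert(value: Any) -> Any:
    """Convert `value`, or raise TypeError if its type was never registered."""
    try:
        converter = _CONVERTERS[type(value)]
    except KeyError:
        raise TypeError(
            f"no converter registered for {type(value).__name__}"
        ) from None
    return converter(value)


class BlockDP:
    """Stands in for a dpmodel sub-component."""

    def __init__(self, width: int) -> None:
        self.width = width


class Block(BlockDP):
    """Stands in for the backend wrapper of BlockDP."""


class UnregisteredDP:
    """A sub-component whose wrapper module was never imported."""


# Registration runs at import time, which is why the wrapper module
# has to be imported before any parent object assigns a BlockDP.
register_mapping(BlockDP, lambda v: Block(v.width))
```

With this in place, `convert(BlockDP(4))` returns a `Block`, whereas `convert(UnregisteredDP())` raises `TypeError`, mirroring the failure described for unregistered `NativeOP` attributes.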
150+
151+
Reference: `deepmd/pt_expt/descriptor/se_t_tebd.py` + `se_t_tebd_block.py`
152+
153+
**Edit** `deepmd/pt_expt/descriptor/__init__.py` — add import and `__all__` entry.
154+
155+
## Step 5: Hard-code for PT backend (if needed)
156+
157+
**Create** `deepmd/pt/model/descriptor/<name>.py`
158+
159+
PT descriptors are fully reimplemented in PyTorch rather than wrapping dpmodel. They inherit from `BaseDescriptor` and `torch.nn.Module`, and must implement `forward()`, `serialize()`, and `deserialize()`.
160+
161+
**Edit** `deepmd/pt/model/descriptor/__init__.py` — add import.
162+
163+
Reference: `deepmd/pt/model/descriptor/se_a.py`
164+
165+
## Step 6: Hard-code for TF backend (if needed)
166+
167+
**Create** `deepmd/tf/descriptor/<name>.py`
168+
169+
TF descriptors are fully reimplemented in TensorFlow. They inherit from `BaseDescriptor` and implement the TF computation graph.
170+
171+
**Edit** `deepmd/tf/descriptor/__init__.py` — add import.
172+
173+
Reference: `deepmd/tf/descriptor/se_a.py`
174+
175+
## Step 7: Hard-code for PD backend (if needed)
176+
177+
**Create** `deepmd/pd/model/descriptor/<name>.py`

Same as PT, but implemented in Paddle. Inherit from `BaseDescriptor` and `paddle.nn.Layer`.
178+
179+
**Edit** `deepmd/pd/model/descriptor/__init__.py` — add import.
180+
181+
Reference: `deepmd/pd/model/descriptor/se_a.py`
182+
183+
## Step 8: Write tests
184+
185+
There are eight test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates.
186+
187+
pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not inherit from `unittest.TestCase`, and use `setup_method` (not `setUp`).
188+
189+
| Test | File | Purpose |
190+
| --------------------- | -------------------------------------------------------------- | ------------------------------------------------- |
191+
| 8a. dpmodel | `source/tests/common/dpmodel/test_descriptor_<name>.py` | Serialize/deserialize round-trip |
192+
| 8b. pt_expt | `source/tests/pt_expt/descriptor/test_<name>.py` | Consistency + exportable + make_fx (float64 only) |
193+
| 8c. PT | `source/tests/pt/model/test_descriptor_<name>.py` | PT hard-coded tests (if applicable) |
194+
| 8d. PD | `source/tests/pd/model/test_descriptor_<name>.py` | PD hard-coded tests (if applicable) |
195+
| 8e. array_api_strict | `source/tests/array_api_strict/descriptor/<name>.py` | Wrapper for consistency tests |
196+
| 8f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parametrized entry |
197+
| 8g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parametrized entry |
198+
| 8h. Consistency | `source/tests/consistent/descriptor/test_<name>.py` | Cross-backend + API consistency |
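The pt_expt conventions noted above (a plain class, `setup_method`, and `pytest.mark.parametrize`) might look like the following sketch. The helper `toy_dim_out` and the class `TestToyDescriptor` are hypothetical and exercise no deepmd-kit code; see the referenced templates for the real test bodies.

```python
import pytest


def toy_dim_out(neuron: list[int], axis_neuron: int) -> int:
    """Hypothetical helper: output width as last layer width times axis_neuron."""
    return neuron[-1] * axis_neuron


class TestToyDescriptor:  # plain class, no unittest.TestCase base
    def setup_method(self) -> None:
        # Runs before every test method (pytest's counterpart to setUp).
        self.base_neuron = [10, 20]

    # Stacked parametrize decorators yield the Cartesian product of the
    # values, which replaces an explicit itertools.product loop.
    @pytest.mark.parametrize("axis_neuron", [4, 8])
    @pytest.mark.parametrize("last_width", [20, 40])
    def test_dim_out(self, last_width: int, axis_neuron: int) -> None:
        neuron = [*self.base_neuron, last_width]
        assert toy_dim_out(neuron, axis_neuron) == last_width * axis_neuron
```

pytest collects four cases from this method, one per combination of `last_width` and `axis_neuron`, and reports each under its own test ID.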
199+
200+
## Step 9: Write documentation
201+
202+
**Create** `doc/model/<name>.md`
203+
204+
Each descriptor needs a documentation page in `doc/model/`. Use MyST Markdown format with Sphinx extensions. List supported backends using icon substitutions.
205+
206+
Template:
207+
208+
````markdown
209+
# Descriptor `"your_name"` {{ pytorch_icon }} {{ dpmodel_icon }}
210+
211+
:::{note}
212+
**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
213+
:::
214+
215+
Brief description of what the descriptor is and its theoretical motivation.
216+
217+
## Theory
218+
219+
Mathematical formulation using LaTeX:
220+
221+
```math
222+
\mathcal{D}^i = ...
223+
```
224+
225+
## Instructions
226+
227+
Example JSON configuration:
228+
229+
```json
230+
"descriptor": {
231+
"type": "your_name",
232+
"sel": [46, 92],
233+
"rcut_smth": 0.50,
234+
"rcut": 6.00,
235+
"neuron": [10, 20, 40],
236+
"resnet_dt": false,
237+
"seed": 1
238+
}
239+
```
240+
241+
Explain key parameters and link to the argument schema using `{ref}` directives,
242+
e.g. {ref}`rcut <model[standard]/descriptor[your_name]/rcut>`.
243+
````
244+
245+
Available backend icons: `{{ tensorflow_icon }}`, `{{ pytorch_icon }}`, `{{ jax_icon }}`, `{{ paddle_icon }}`, `{{ dpmodel_icon }}`. Only list backends that actually support this descriptor.
246+
247+
**Edit** `doc/model/index.rst` — add the new page to the `toctree`:
248+
249+
```rst
250+
.. toctree::
251+
:maxdepth: 1
252+
253+
...
254+
<name>
255+
```
256+
257+
**Reference docs**: `doc/model/train-se-e2-r.md` (simple), `doc/model/dpa2.md` (modern)
258+
259+
## Verification
260+
261+
```bash
262+
# dpmodel self-consistency
263+
python -m pytest source/tests/common/dpmodel/test_descriptor_<name>.py -v
264+
265+
# pt_expt unit tests
266+
python -m pytest source/tests/pt_expt/descriptor/test_<name>.py -v
267+
268+
# Cross-backend consistency
269+
python -m pytest source/tests/consistent/descriptor/test_<name>.py -v
270+
271+
# PT/PD unit tests (if hard-coded)
272+
python -m pytest source/tests/pt/model/test_descriptor_<name>.py -v
273+
python -m pytest source/tests/pd/model/test_descriptor_<name>.py -v
274+
275+
# Quick smoke test
276+
python -c "
277+
from deepmd.dpmodel.descriptor import DescrptYourName
278+
d = DescrptYourName(rcut=6.0, rcut_smth=1.8, sel=[20, 20])
279+
d2 = DescrptYourName.deserialize(d.serialize())
280+
print('Round-trip OK:', d.get_dim_out() == d2.get_dim_out())
281+
"
282+
```
283+
284+
## Files summary
285+
286+
| Step | Action | File |
287+
| ---- | ------ | -------------------------------------------------------------- |
288+
| 1 | Create | `deepmd/dpmodel/descriptor/<name>.py` |
289+
| 2 | Edit | `deepmd/dpmodel/descriptor/__init__.py` |
290+
| 2 | Edit | `deepmd/utils/argcheck.py` |
291+
| 3 | Create | `deepmd/jax/descriptor/<name>.py` |
292+
| 3 | Edit | `deepmd/jax/descriptor/__init__.py` |
293+
| 4 | Create | `deepmd/pt_expt/descriptor/<name>.py` |
294+
| 4 | Edit | `deepmd/pt_expt/descriptor/__init__.py` |
295+
| 5 | Create | `deepmd/pt/model/descriptor/<name>.py` (if needed) |
296+
| 5 | Edit | `deepmd/pt/model/descriptor/__init__.py` (if needed) |
297+
| 6 | Create | `deepmd/tf/descriptor/<name>.py` (if needed) |
298+
| 6 | Edit | `deepmd/tf/descriptor/__init__.py` (if needed) |
299+
| 7 | Create | `deepmd/pd/model/descriptor/<name>.py` (if needed) |
300+
| 7 | Edit | `deepmd/pd/model/descriptor/__init__.py` (if needed) |
301+
| 8a | Create | `source/tests/common/dpmodel/test_descriptor_<name>.py` |
302+
| 8b | Create | `source/tests/pt_expt/descriptor/test_<name>.py` |
303+
| 8c | Create | `source/tests/pt/model/test_descriptor_<name>.py` (if PT) |
304+
| 8d | Create | `source/tests/pd/model/test_descriptor_<name>.py` (if PD) |
305+
| 8e | Create | `source/tests/array_api_strict/descriptor/<name>.py` |
306+
| 8e | Edit | `source/tests/array_api_strict/descriptor/__init__.py` |
307+
| 8f | Edit | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` |
308+
| 8g | Edit | `source/tests/universal/pt/descriptor/test_descriptor.py` |
309+
| 8h | Create | `source/tests/consistent/descriptor/test_<name>.py` |
310+
| 9 | Create | `doc/model/<name>.md` |
311+
| 9 | Edit | `doc/model/index.rst` |
