| 1 | +--- |
| 2 | +name: add-descriptor |
| 3 | +description: Guides through adding a new descriptor type to deepmd-kit. Covers implementing in dpmodel (array-API-compatible), wrapping for JAX/pt_expt backends, hard-coding for PT/TF/PD, registering arguments, writing all required tests, and writing documentation. |
| 4 | +license: LGPL-3.0-or-later |
| 5 | +compatibility: Requires Python 3.10+, numpy, pytest. Optional backends for full testing (torch, jax, paddle). |
| 6 | +metadata: |
| 7 | + author: deepmd-kit |
| 8 | + version: "2.0" |
| 9 | +--- |
| 10 | + |
| 11 | +# Adding a New Descriptor to deepmd-kit |
| 12 | + |
| 13 | +Follow these steps in order. Each step lists files to create/modify and patterns to follow. |
| 14 | + |
| 15 | +## Step 1: Implement in dpmodel |
| 16 | + |
| 17 | +**Create** `deepmd/dpmodel/descriptor/<name>.py` |
| 18 | + |
| 19 | +Inherit from `NativeOP` and `BaseDescriptor`. Register with decorators: |
| 20 | + |
| 21 | +```python |
| 22 | +from deepmd.dpmodel import NativeOP |
| 23 | +from .base_descriptor import BaseDescriptor |
| 24 | + |
| 25 | + |
| 26 | +@BaseDescriptor.register("your_name") |
| 27 | +@BaseDescriptor.register("alias_name") # optional aliases |
| 28 | +class DescrptYourName(NativeOP, BaseDescriptor): ... |
| 29 | +``` |
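The stacked `register` decorators above map several names to one class. A minimal sketch of that pattern follows, assuming a toy `Registry` class and a `lookup` method that are hypothetical and not the deepmd-kit implementation:

```python
from typing import Callable


class Registry:
    """Toy name -> class registry (hypothetical; not deepmd-kit's own code)."""

    def __init__(self) -> None:
        self._classes: dict[str, type] = {}

    def register(self, key: str) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            self._classes[key] = cls
            # Returning cls unchanged is what lets the decorators stack.
            return cls

        return decorator

    def lookup(self, key: str) -> type:
        if key not in self._classes:
            raise KeyError(f"unknown descriptor type: {key}")
        return self._classes[key]


registry = Registry()


@registry.register("your_name")
@registry.register("alias_name")  # optional alias
class DescrptYourName:
    pass


# Both names resolve to the same class object.
assert registry.lookup("your_name") is DescrptYourName
assert registry.lookup("alias_name") is DescrptYourName
```

Because each decorator returns the class itself, any number of aliases can be stacked without wrapping or copying the class.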
| 30 | + |
| 31 | +Key requirements: |
| 32 | + |
| 33 | +- `__init__`: initialize cutoff, sel, networks, davg/dstd statistics |
| 34 | +- `call(coord_ext, atype_ext, nlist, mapping=None)`: forward pass returning `(descriptor, rot_mat, g2, h2, sw)` |
| 35 | +- `serialize() -> dict`: save with `@class`, `type`, `@version`, `@variables` keys |
| 36 | +- `deserialize(cls, data)`: classmethod that reconstructs the descriptor from the dict produced by `serialize()` |
| 37 | +- Property/getter methods: `get_rcut`, `get_sel`, `get_dim_out`, `mixed_types`, etc. |
| 38 | +- `__getitem__`/`__setitem__` for `davg`/`dstd` access via multiple key aliases |
| 39 | + |
| 40 | +All dpmodel code **must** use `array_api_compat` for cross-backend compatibility (numpy/torch/jax/paddle). See [references/dpmodel-implementation.md](references/dpmodel-implementation.md) for full method table, array API pitfalls, and utilities. |
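The serialization and key-alias requirements listed above can be sketched with a dependency-free toy class. Everything here is an illustrative assumption: `ToyDescriptor`, the alias names in `_KEY_ALIASES`, and the use of plain lists in place of arrays. A real descriptor inherits `NativeOP` and `BaseDescriptor` and stores array-API arrays.

```python
class ToyDescriptor:
    """Hypothetical stand-in showing the serialize/deserialize contract."""

    # Hypothetical alias table: several keys resolve to the same attribute.
    _KEY_ALIASES = {
        "avg": "davg", "data_avg": "davg", "davg": "davg",
        "std": "dstd", "data_std": "dstd", "dstd": "dstd",
    }

    def __init__(self, rcut, rcut_smth, sel):
        self.rcut = rcut
        self.rcut_smth = rcut_smth
        self.sel = list(sel)
        nnei = sum(self.sel)
        # Plain lists keep the sketch dependency-free; real code stores arrays.
        self.davg = [0.0] * nnei
        self.dstd = [1.0] * nnei

    def _resolve(self, key):
        if key not in self._KEY_ALIASES:
            raise KeyError(key)
        return self._KEY_ALIASES[key]

    def __getitem__(self, key):
        return getattr(self, self._resolve(key))

    def __setitem__(self, key, value):
        setattr(self, self._resolve(key), value)

    def serialize(self):
        return {
            "@class": "Descriptor",
            "type": "toy",
            "@version": 1,
            "rcut": self.rcut,
            "rcut_smth": self.rcut_smth,
            "sel": self.sel,
            "@variables": {"davg": list(self.davg), "dstd": list(self.dstd)},
        }

    @classmethod
    def deserialize(cls, data):
        data = dict(data)  # shallow copy so the caller's dict is untouched
        data.pop("@class")
        data.pop("type")
        version = data.pop("@version")
        if version != 1:
            raise ValueError(f"unsupported version: {version}")
        variables = data.pop("@variables")
        obj = cls(**data)  # remaining keys are the constructor arguments
        obj["davg"] = variables["davg"]
        obj["dstd"] = variables["dstd"]
        return obj


d = ToyDescriptor(rcut=6.0, rcut_smth=0.5, sel=[2, 3])
d["avg"] = [0.1] * 5  # write through an alias
d2 = ToyDescriptor.deserialize(d.serialize())
assert d2.serialize() == d.serialize()
assert d2["data_avg"] == d["davg"]
```

Keeping the constructor arguments at the top level of the dict and the statistics under `@variables` lets `deserialize` rebuild the object with `cls(**data)` and then restore the arrays separately.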
| 41 | + |
| 42 | +**Reference implementations**: |
| 43 | + |
| 44 | +- Simple: `deepmd/dpmodel/descriptor/se_e2_a.py` |
| 45 | +- Three-body: `deepmd/dpmodel/descriptor/se_t.py` |
| 46 | +- Attention-based: `deepmd/dpmodel/descriptor/dpa1.py` |
| 47 | + |
| 48 | +## Step 2: Register |
| 49 | + |
| 50 | +**Edit** `deepmd/dpmodel/descriptor/__init__.py` — add import and `__all__` entry. |
| 51 | + |
| 52 | +**Edit** `deepmd/utils/argcheck.py` — register descriptor arguments: |
| 53 | + |
| 54 | +```python |
| 55 | +@descrpt_args_plugin.register("your_name", alias=["alias"], doc="Description") |
| 56 | +def descrpt_your_name_args() -> list[Argument]: |
| 57 | + return [ |
| 58 | + Argument("sel", [list[int], str], optional=True, default="auto", doc=doc_sel), |
| 59 | + Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), |
| 60 | + Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth), |
| 61 | + Argument( |
| 62 | + "neuron", list[int], optional=True, default=[10, 20, 40], doc=doc_neuron |
| 63 | + ), |
| 64 | + # ... add all constructor parameters |
| 65 | + ] |
| 66 | +``` |
| 67 | + |
| 68 | +## Step 3: Wrap for JAX backend |
| 69 | + |
| 70 | +**Create** `deepmd/jax/descriptor/<name>.py` |
| 71 | + |
| 72 | +Pattern: `@flax_module` decorator + custom `__setattr__` for attribute conversion. |
| 73 | + |
| 74 | +```python |
| 75 | +from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP |
| 76 | +from deepmd.jax.common import ArrayAPIVariable, flax_module, to_jax_array |
| 77 | +from deepmd.jax.descriptor.base_descriptor import BaseDescriptor |
| 78 | +from deepmd.jax.utils.exclude_mask import PairExcludeMask |
| 79 | +from deepmd.jax.utils.network import NetworkCollection |
| 80 | + |
| 81 | + |
| 82 | +@BaseDescriptor.register("your_name") |
| 83 | +@flax_module |
| 84 | +class DescrptYourName(DescrptYourNameDP): |
| 85 | + def __setattr__(self, name, value): |
| 86 | + if name in {"davg", "dstd"}: |
| 87 | + value = to_jax_array(value) |
| 88 | + if value is not None: |
| 89 | + value = ArrayAPIVariable(value) |
| 90 | + elif name in {"embeddings"}: |
| 91 | + if value is not None: |
| 92 | + value = NetworkCollection.deserialize(value.serialize()) |
| 93 | + elif name == "env_mat": |
| 94 | + pass # stateless |
| 95 | + elif name == "emask": |
| 96 | + value = PairExcludeMask(value.ntypes, value.exclude_types) |
| 97 | + return super().__setattr__(name, value) |
| 98 | +``` |
| 99 | + |
| 100 | +For nested sub-components, define wrapper classes bottom-up. See `deepmd/jax/descriptor/dpa1.py` for an example. |
| 101 | + |
| 102 | +**Edit** `deepmd/jax/descriptor/__init__.py` — add import and `__all__` entry. |
| 103 | + |
| 104 | +## Step 4: Wrap for pt_expt backend |
| 105 | + |
| 106 | +**Create** `deepmd/pt_expt/descriptor/<name>.py` |
| 107 | + |
| 108 | +The `@torch_module` decorator generates the backend glue automatically: |
| 109 | + |
| 110 | +- Auto-generates `forward()` delegating to `call()` (and `forward_lower()` from `call_lower()`) |
| 111 | +- Auto-generates `__setattr__` that converts numpy arrays to torch buffers and dpmodel objects to pt_expt modules via a converter registry |
| 112 | +- Any unregistered `NativeOP` assigned as an attribute will raise `TypeError` — register it first |
| 113 | + |
| 114 | +Simple descriptors (no custom sub-components) need only an empty body: |
| 115 | + |
| 116 | +```python |
| 117 | +from deepmd.dpmodel.descriptor.your_name import DescrptYourName as DescrptYourNameDP |
| 118 | +from deepmd.pt_expt.common import torch_module |
| 119 | +from deepmd.pt_expt.descriptor.base_descriptor import BaseDescriptor |
| 120 | + |
| 121 | + |
| 122 | +@BaseDescriptor.register("your_name") |
| 123 | +@torch_module |
| 124 | +class DescrptYourName(DescrptYourNameDP): |
| 125 | + pass |
| 126 | +``` |
| 127 | + |
| 128 | +Standard dpmodel sub-components (`NetworkCollection`, `EmbeddingNet`, `PairExcludeMask`, `EnvMat`, `TypeEmbedNet`) are pre-registered in `deepmd/pt_expt/utils/` and converted automatically. No `__setattr__` override needed. |
| 129 | + |
| 130 | +For **custom sub-components** (e.g., a new block class inheriting `NativeOP`), create a separate wrapper file and register bottom-up with `register_dpmodel_mapping`: |
| 131 | + |
| 132 | +```python |
| 133 | +# deepmd/pt_expt/descriptor/your_block.py |
| 134 | +from deepmd.dpmodel.descriptor.your_block import YourBlock as YourBlockDP |
| 135 | +from deepmd.pt_expt.common import register_dpmodel_mapping, torch_module |
| 136 | + |
| 137 | + |
| 138 | +@torch_module |
| 139 | +class YourBlock(YourBlockDP): |
| 140 | + pass |
| 141 | + |
| 142 | + |
| 143 | +register_dpmodel_mapping( |
| 144 | + YourBlockDP, |
| 145 | + lambda v: YourBlock.deserialize(v.serialize()), |
| 146 | +) |
| 147 | +``` |
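To illustrate why registration must happen bottom-up, here is a toy converter registry. It is a sketch of the general idea only: `register_mapping`, `Op`, `Block`, `WrappedBlock`, and `Parent` are hypothetical names, and the real `torch_module` and `register_dpmodel_mapping` may work differently in detail.

```python
_CONVERTERS = {}  # maps a source class to a function producing its wrapped form


def register_mapping(source_cls, converter):
    _CONVERTERS[source_cls] = converter


class Op:
    """Stand-in for a backend-agnostic component."""


class Block(Op):
    def __init__(self, width):
        self.width = width


class WrappedBlock(Block):
    """Stand-in for the backend-specific wrapper of Block."""


class UnknownOp(Op):
    """A component nobody registered a converter for."""


class Parent:
    def __setattr__(self, name, value):
        # Exact-type lookup, so already-wrapped objects are not converted again.
        converter = _CONVERTERS.get(type(value))
        if converter is not None:
            value = converter(value)
        elif isinstance(value, Op) and not isinstance(value, WrappedBlock):
            raise TypeError(f"no converter registered for {type(value).__name__}")
        super().__setattr__(name, value)


# The registration has to run before any Parent assigns a Block attribute.
register_mapping(Block, lambda v: WrappedBlock(v.width))

parent = Parent()
parent.block = Block(width=8)
assert type(parent.block) is WrappedBlock

try:
    parent.other = UnknownOp()
except TypeError as err:
    print("rejected:", err)
```

In this sketch the parent can only convert components whose converters already exist at assignment time, which is the same ordering constraint as importing the block module before the parent descriptor is instantiated.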
| 148 | + |
| 149 | +Then import this module in `deepmd/pt_expt/descriptor/__init__.py` for its side effect (the registration must happen before the parent descriptor is instantiated). |
| 150 | + |
| 151 | +Reference: `deepmd/pt_expt/descriptor/se_t_tebd.py` + `se_t_tebd_block.py` |
| 152 | + |
| 153 | +**Edit** `deepmd/pt_expt/descriptor/__init__.py` — add import and `__all__` entry. |
| 154 | + |
| 155 | +## Step 5: Hard-code for PT backend (if needed) |
| 156 | + |
| 157 | +**Create** `deepmd/pt/model/descriptor/<name>.py` |
| 158 | + |
| 159 | +PT descriptors are fully reimplemented in PyTorch rather than wrapping dpmodel. They inherit from `BaseDescriptor` and `torch.nn.Module` and must implement `forward()`, `serialize()`, and `deserialize()`. |
| 160 | + |
| 161 | +**Edit** `deepmd/pt/model/descriptor/__init__.py` — add import. |
| 162 | + |
| 163 | +Reference: `deepmd/pt/model/descriptor/se_a.py` |
| 164 | + |
| 165 | +## Step 6: Hard-code for TF backend (if needed) |
| 166 | + |
| 167 | +**Create** `deepmd/tf/descriptor/<name>.py` |
| 168 | + |
| 169 | +TF descriptors are fully reimplemented in TensorFlow. They inherit from `BaseDescriptor` and implement the TF computation graph. |
| 170 | + |
| 171 | +**Edit** `deepmd/tf/descriptor/__init__.py` — add import. |
| 172 | + |
| 173 | +Reference: `deepmd/tf/descriptor/se_a.py` |
| 174 | + |
| 175 | +## Step 7: Hard-code for PD backend (if needed) |
| 176 | + |
| 177 | +**Create** `deepmd/pd/model/descriptor/<name>.py`. The approach is the same as for PT but uses Paddle: inherit from `BaseDescriptor` and `paddle.nn.Layer`. |
| 178 | + |
| 179 | +**Edit** `deepmd/pd/model/descriptor/__init__.py` — add import. |
| 180 | + |
| 181 | +Reference: `deepmd/pd/model/descriptor/se_a.py` |
| 182 | + |
| 183 | +## Step 8: Write tests |
| 184 | + |
| 185 | +There are eight test categories. See [references/test-patterns.md](references/test-patterns.md) for full code templates. |
| 186 | + |
| 187 | +pt_expt tests use `pytest.mark.parametrize` (not `itertools.product`), do not inherit from `unittest.TestCase`, and use `setup_method` (not `setUp`). |
| 188 | + |
| 189 | +| Test | File | Purpose | |
| 190 | +| --------------------- | -------------------------------------------------------------- | ------------------------------------------------- | |
| 191 | +| 8a. dpmodel | `source/tests/common/dpmodel/test_descriptor_<name>.py` | Serialize/deserialize round-trip | |
| 192 | +| 8b. pt_expt | `source/tests/pt_expt/descriptor/test_<name>.py` | Consistency + exportable + make_fx (float64 only) | |
| 193 | +| 8c. PT | `source/tests/pt/model/test_descriptor_<name>.py` | PT hard-coded tests (if applicable) | |
| 194 | +| 8d. PD | `source/tests/pd/model/test_descriptor_<name>.py` | PD hard-coded tests (if applicable) | |
| 195 | +| 8e. array_api_strict | `source/tests/array_api_strict/descriptor/<name>.py` | Wrapper for consistency tests | |
| 196 | +| 8f. Universal dpmodel | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | Add parametrized entry | |
| 197 | +| 8g. Universal PT | `source/tests/universal/pt/descriptor/test_descriptor.py` | Add parametrized entry | |
| 198 | +| 8h. Consistency | `source/tests/consistent/descriptor/test_<name>.py` | Cross-backend + API consistency | |
| 199 | + |
| 200 | +## Step 9: Write documentation |
| 201 | + |
| 202 | +**Create** `doc/model/<name>.md` |
| 203 | + |
| 204 | +Each descriptor needs a documentation page in `doc/model/`. Use MyST Markdown format with Sphinx extensions. List supported backends using icon substitutions. |
| 205 | + |
| 206 | +Template: |
| 207 | + |
| 208 | +````markdown |
| 209 | +# Descriptor `"your_name"` {{ pytorch_icon }} {{ dpmodel_icon }} |
| 210 | + |
| 211 | +:::{note} |
| 212 | +**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} |
| 213 | +::: |
| 214 | + |
| 215 | +Brief description of what the descriptor is and its theoretical motivation. |
| 216 | + |
| 217 | +## Theory |
| 218 | + |
| 219 | +Mathematical formulation using LaTeX: |
| 220 | + |
| 221 | +```math |
| 222 | + \mathcal{D}^i = ... |
| 223 | +``` |
| 224 | + |
| 225 | +## Instructions |
| 226 | + |
| 227 | +Example JSON configuration: |
| 228 | + |
| 229 | +```json |
| 230 | +"descriptor": { |
| 231 | + "type": "your_name", |
| 232 | + "sel": [46, 92], |
| 233 | + "rcut_smth": 0.50, |
| 234 | + "rcut": 6.00, |
| 235 | + "neuron": [10, 20, 40], |
| 236 | + "resnet_dt": false, |
| 237 | + "seed": 1 |
| 238 | +} |
| 239 | +``` |
| 240 | + |
| 241 | +Explain key parameters and link to the argument schema using the `{ref}` role, |
| 242 | +e.g. `` {ref}`rcut <model[standard]/descriptor[your_name]/rcut>` ``. |
| 243 | +```` |
| 244 | + |
| 245 | +Available backend icons: `{{ tensorflow_icon }}`, `{{ pytorch_icon }}`, `{{ jax_icon }}`, `{{ paddle_icon }}`, `{{ dpmodel_icon }}`. Only list backends that actually support this descriptor. |
| 246 | + |
| 247 | +**Edit** `doc/model/index.rst` — add the new page to the `toctree`: |
| 248 | + |
| 249 | +```rst |
| 250 | +.. toctree:: |
| 251 | + :maxdepth: 1 |
| 252 | + |
| 253 | + ... |
| 254 | + <name> |
| 255 | +``` |
| 256 | + |
| 257 | +**Reference docs**: `doc/model/train-se-e2-r.md` (simple), `doc/model/dpa2.md` (modern) |
| 258 | + |
| 259 | +## Verification |
| 260 | + |
| 261 | +```bash |
| 262 | +# dpmodel self-consistency |
| 263 | +python -m pytest source/tests/common/dpmodel/test_descriptor_<name>.py -v |
| 264 | + |
| 265 | +# pt_expt unit tests |
| 266 | +python -m pytest source/tests/pt_expt/descriptor/test_<name>.py -v |
| 267 | + |
| 268 | +# Cross-backend consistency |
| 269 | +python -m pytest source/tests/consistent/descriptor/test_<name>.py -v |
| 270 | + |
| 271 | +# PT/PD unit tests (if hard-coded) |
| 272 | +python -m pytest source/tests/pt/model/test_descriptor_<name>.py -v |
| 273 | +python -m pytest source/tests/pd/model/test_descriptor_<name>.py -v |
| 274 | + |
| 275 | +# Quick smoke test |
| 276 | +python -c " |
| 277 | +from deepmd.dpmodel.descriptor import DescrptYourName |
| 278 | +d = DescrptYourName(rcut=6.0, rcut_smth=1.8, sel=[20, 20]) |
| 279 | +d2 = DescrptYourName.deserialize(d.serialize()) |
| 280 | +print('Round-trip OK:', d.get_dim_out() == d2.get_dim_out()) |
| 281 | +" |
| 282 | +``` |
| 283 | + |
| 284 | +## Files summary |
| 285 | + |
| 286 | +| Step | Action | File | |
| 287 | +| ---- | ------ | -------------------------------------------------------------- | |
| 288 | +| 1 | Create | `deepmd/dpmodel/descriptor/<name>.py` | |
| 289 | +| 2 | Edit | `deepmd/dpmodel/descriptor/__init__.py` | |
| 290 | +| 2 | Edit | `deepmd/utils/argcheck.py` | |
| 291 | +| 3 | Create | `deepmd/jax/descriptor/<name>.py` | |
| 292 | +| 3 | Edit | `deepmd/jax/descriptor/__init__.py` | |
| 293 | +| 4 | Create | `deepmd/pt_expt/descriptor/<name>.py` | |
| 294 | +| 4 | Edit | `deepmd/pt_expt/descriptor/__init__.py` | |
| 295 | +| 5 | Create | `deepmd/pt/model/descriptor/<name>.py` (if needed) | |
| 296 | +| 5 | Edit | `deepmd/pt/model/descriptor/__init__.py` (if needed) | |
| 297 | +| 6 | Create | `deepmd/tf/descriptor/<name>.py` (if needed) | |
| 298 | +| 6 | Edit | `deepmd/tf/descriptor/__init__.py` (if needed) | |
| 299 | +| 7 | Create | `deepmd/pd/model/descriptor/<name>.py` (if needed) | |
| 300 | +| 7 | Edit | `deepmd/pd/model/descriptor/__init__.py` (if needed) | |
| 301 | +| 8a | Create | `source/tests/common/dpmodel/test_descriptor_<name>.py` | |
| 302 | +| 8b | Create | `source/tests/pt_expt/descriptor/test_<name>.py` | |
| 303 | +| 8c | Create | `source/tests/pt/model/test_descriptor_<name>.py` (if PT) | |
| 304 | +| 8d | Create | `source/tests/pd/model/test_descriptor_<name>.py` (if PD) | |
| 305 | +| 8e | Create | `source/tests/array_api_strict/descriptor/<name>.py` | |
| 306 | +| 8e | Edit | `source/tests/array_api_strict/descriptor/__init__.py` | |
| 307 | +| 8f | Edit | `source/tests/universal/dpmodel/descriptor/test_descriptor.py` | |
| 308 | +| 8g | Edit | `source/tests/universal/pt/descriptor/test_descriptor.py` | |
| 309 | +| 8h | Create | `source/tests/consistent/descriptor/test_<name>.py` | |
| 310 | +| 9 | Create | `doc/model/<name>.md` | |
| 311 | +| 9 | Edit | `doc/model/index.rst` | |