Skip to content

Commit 0918b22

Browse files
iProzdnjzjz
andauthored
doc: update DPA3 doc and example (#4655)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Documentation** - Updated documentation to correctly reference a DPA-2 example. - Introduced new documentation for the advanced DPA-3 model outlining its capabilities, training benchmarks, and installation requirements. - Expanded the documentation index and spin configuration sections to include DPA-3. - **New Features** - Added a README with configuration details for training a 6-layer DPA-3 model. - Provided a comprehensive JSON configuration file with training parameters. - Updated simulation instructions to support both DPA-2 and DPA-3. - **Tests** - Extended testing to cover DPA-3 configurations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 00b6d55 commit 0918b22

8 files changed

Lines changed: 179 additions & 2 deletions

File tree

doc/model/dpa2.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ When using the JAX backend, 2 or more MPI ranks are not supported. One must set
2626
atom_modify map yes
2727
```
2828

29-
See the example `examples/water/lmp/jax_dpa2.lammps`.
29+
See the example `examples/water/lmp/jax_dpa.lammps`.
3030

3131
## Data format
3232

doc/model/dpa3.md

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Descriptor DPA-3 {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
2+
3+
:::{note}
4+
**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
5+
:::
6+
7+
DPA-3 is an advanced interatomic potential leveraging the message passing architecture.
8+
Designed as a large atomic model (LAM), DPA-3 is tailored to integrate and simultaneously train on datasets from various disciplines,
9+
encompassing diverse chemical and materials systems across different research domains.
10+
Its model design ensures exceptional fitting accuracy and robust generalization both within and beyond the training domain.
11+
Furthermore, DPA-3 maintains energy conservation and respects the physical symmetries of the potential energy surface,
12+
making it a dependable tool for a wide range of scientific applications.
13+
14+
Reference: will be released soon.
15+
16+
Training example: `examples/water/dpa3/input_torch.json`.
17+
18+
## Hyperparameter tests
19+
20+
We systematically conducted DPA-3 training on six representative DFT datasets (available at [AIS-Square](https://www.aissquare.com/datasets/detail?pageType=datasets&name=DPA3_hyperparameter_search&id=316)):
21+
metallic systems (`Alloy`, `AlMgCu`, `W`), covalent material (`Boron`), molecular system (`Drug`), and liquid water (`Water`).
22+
Under consistent training conditions (0.5M training steps, batch_size "auto:128"),
23+
we rigorously evaluated the impacts of some critical hyperparameters on validation accuracy.
24+
25+
The comparative analysis focused on average RMSEs (Root Mean Square Error) for both energy, force and virial predictions across all six systems,
26+
with results tabulated below to guide scenario-specific hyperparameter selection:
27+
28+
| Model | comment | nlayers | n_dim | e_dim | a_dim | e_sel | a_sel | start_lr | stop_lr | loss prefactors | rmse_e (meV/atom) | rmse_f (meV/Å) | rmse_v (meV/atom) | Training wall time (h) |
29+
| ---------------- | --------------- | ------- | ------- | ------ | ----- | ------- | ------ | -------- | -------- | ------------------------- | ----------------- | -------------- | ----------------- | ---------------------- |
30+
| DPA3-L3 | Default | 3 | 256 | 128 | 32 | 120 | 30 | 1e-3 | 3e-5 | 0.2\|20, 100\|60, 0.02\|1 | 5.74 | 85.4 | 43.1 | 9.8 |
31+
| | Small dimension | 3 | **128** | **64** | 32 | 120 | 30 | 1e-3 | 3e-5 | 0.2\|20, 100\|60, 0.02\|1 | 6.99 | 93.6 | 46.7 | 8.0 |
32+
| | Large sel | 3 | 256 | 128 | 32 | **154** | **48** | 1e-3 | 3e-5 | 0.2\|20, 100\|60, 0.02\|1 | 5.70 | 83.7 | 43.4 | 14.1 |
33+
| DPA3-L6 | Default | 6 | 256 | 128 | 32 | 120 | 30 | 1e-3 | 3e-5 | 0.2\|20, 100\|60, 0.02\|1 | 4.85 | 79.9 | 39.7 | 19.2 |
34+
| | Small dimension | 6 | **128** | **64** | 32 | 120 | 30 | 1e-3 | 3e-5 | 0.2\|20, 100\|60, 0.02\|1 | 5.11 | 77.7 | 41.2 | 14.1 |
35+
| | Large sel | 6 | 256 | 128 | 32 | **154** | **48** | 1e-3 | 3e-5 | 0.2\|20, 100\|60, 0.02\|1 | 4.76 | 78.4 | 40.2 | 31.8 |
36+
| DPA2-L6 (medium) | Default | 6 | - | - | - | - | - | 1e-3 | 3.51e-08 | 0.02\|1, 1000\|1, 0.02\|1 | 12.12 | 109.3 | 83.1 | 12.2 |
37+
38+
The loss prefactors (0.2|20, 100|60, 0.02|1) correspond to (`start_pref_e`|`limit_pref_e`, `start_pref_f`|`limit_pref_f`, `start_pref_v`|`limit_pref_v`) respectively.
39+
Virial RMSEs were averaged exclusively for systems containing virial labels (`Alloy`, `AlMgCu`, `W`, and `Boron`).
40+
41+
Note that we set `float32` in all DPA-3 models, while `float64` in other models by default.
42+
43+
## Requirements of installation from source code {{ pytorch_icon }}
44+
45+
To run the DPA-3 model on LAMMPS via source code installation
46+
(users can skip this step if using [easy installation](../install/easy-install.md)),
47+
the custom OP library for Python interface integration must be compiled and linked
48+
during the [model freezing process](../freeze/freeze.md).
49+
50+
The customized OP library for the Python interface can be installed by setting environment variable {envvar}`DP_ENABLE_PYTORCH` to `1` during installation.
51+
52+
If one runs LAMMPS with MPI, the customized OP library for the C++ interface should be compiled against the same MPI library as the runtime MPI.
53+
If one runs LAMMPS with MPI and CUDA devices, it is recommended to compile the customized OP library for the C++ interface with a [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library and CUDA,
54+
otherwise the communication between GPU cards falls back to the slower CPU implementation.
55+
56+
## Limitations of the JAX backend with LAMMPS {{ jax_icon }}
57+
58+
When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.
59+
60+
```lammps
61+
atom_modify map yes
62+
```
63+
64+
See the example `examples/water/lmp/jax_dpa.lammps`.
65+
66+
## Data format
67+
68+
DPA-3 supports both the [standard data format](../data/system.md) and the [mixed type data format](../data/system.md#mixed-type).
69+
70+
## Type embedding
71+
72+
Type embedding is within this descriptor with the same dimension as the node embedding: {ref}`n_dim <model[standard]/descriptor[dpa3]/repflow/n_dim>` argument.
73+
74+
## Model compression
75+
76+
Model compression is not supported in this descriptor.

doc/model/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Model
1010
train-se-e3
1111
train-se-atten
1212
dpa2
13+
dpa3
1314
train-hybrid
1415
sel
1516
train-energy

doc/model/train-energy-spin.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ In PyTorch/DP, the spin implementation is more flexible and so far supports the
5151
- `se_e2_a`
5252
- `dpa1`(`se_atten`)
5353
- `dpa2`
54+
- `dpa3`
5455

5556
See `se_e2_a` examples in `$deepmd_source_dir/examples/spin/se_e2_a/input_torch.json`, the {ref}`spin <model/spin>` section is defined as the following with a much more clear interface:
5657

examples/water/dpa3/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Input for the DPA-3 model
2+
3+
This directory stores configuration files for training the 6-layer DPA-3 model.
4+
For comprehensive hyperparameter selection, consult the [DPA-3 documentation](../../../doc/model/dpa3.md/#hyperparameter-tests).
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
{
2+
"_comment": "that's all",
3+
"model": {
4+
"type_map": [
5+
"O",
6+
"H"
7+
],
8+
"descriptor": {
9+
"type": "dpa3",
10+
"repflow": {
11+
"n_dim": 256,
12+
"e_dim": 128,
13+
"a_dim": 32,
14+
"nlayers": 6,
15+
"e_rcut": 6.0,
16+
"e_rcut_smth": 3.0,
17+
"e_sel": 120,
18+
"a_rcut": 4.0,
19+
"a_rcut_smth": 2.0,
20+
"a_sel": 30,
21+
"axis_neuron": 4,
22+
"skip_stat": true,
23+
"a_compress_rate": 1,
24+
"a_compress_e_rate": 2,
25+
"a_compress_use_split": true,
26+
"update_angle": true,
27+
"update_style": "res_residual",
28+
"update_residual": 0.1,
29+
"update_residual_init": "const"
30+
},
31+
"activation_function": "silut:10.0",
32+
"use_tebd_bias": false,
33+
"precision": "float32",
34+
"concat_output_tebd": false
35+
},
36+
"fitting_net": {
37+
"neuron": [
38+
240,
39+
240,
40+
240
41+
],
42+
"resnet_dt": true,
43+
"precision": "float32",
44+
"activation_function": "silut:10.0",
45+
"seed": 1,
46+
"_comment": " that's all"
47+
},
48+
"_comment": " that's all"
49+
},
50+
"learning_rate": {
51+
"type": "exp",
52+
"decay_steps": 5000,
53+
"start_lr": 0.001,
54+
"stop_lr": 3e-5,
55+
"_comment": "that's all"
56+
},
57+
"loss": {
58+
"type": "ener",
59+
"start_pref_e": 0.2,
60+
"limit_pref_e": 20,
61+
"start_pref_f": 100,
62+
"limit_pref_f": 60,
63+
"start_pref_v": 0.02,
64+
"limit_pref_v": 1,
65+
"_comment": " that's all"
66+
},
67+
"training": {
68+
"stat_file": "./dpa3.hdf5",
69+
"training_data": {
70+
"systems": [
71+
"../data/data_0",
72+
"../data/data_1",
73+
"../data/data_2"
74+
],
75+
"batch_size": 1,
76+
"_comment": "that's all"
77+
},
78+
"validation_data": {
79+
"systems": [
80+
"../data/data_3"
81+
],
82+
"batch_size": 1,
83+
"_comment": "that's all"
84+
},
85+
"numb_steps": 1000000,
86+
"warmup_steps": 0,
87+
"gradient_max_norm": 5.0,
88+
"seed": 10,
89+
"disp_file": "lcurve.out",
90+
"disp_freq": 100,
91+
"save_freq": 2000,
92+
"_comment": "that's all"
93+
}
94+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
units metal
66
boundary p p p
77
atom_style atomic
8-
# Below line is required when using DPA-2 with the JAX backend
8+
# Below line is required when using DPA-2/3 with the JAX backend
99
atom_modify map yes
1010

1111
neighbor 2.0 bin

source/tests/common/test_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
p_examples / "water" / "dpa2" / "input_torch_medium.json",
5959
p_examples / "water" / "dpa2" / "input_torch_large.json",
6060
p_examples / "water" / "dpa2" / "input_torch_compressible.json",
61+
p_examples / "water" / "dpa3" / "input_torch.json",
6162
p_examples / "property" / "train" / "input_torch.json",
6263
p_examples / "water" / "se_e3_tebd" / "input_torch.json",
6364
p_examples / "hessian" / "single_task" / "input.json",

0 commit comments

Comments
 (0)